1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_ 18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_ 19 20 #include <memory> 21 #include <vector> 22 #include <string> 23 #include <unordered_map> 24 25 #include "transform/graph_ir/op_adapter_util.h" 26 #include "utils/utils.h" 27 namespace mindspore { 28 namespace transform { 29 class OpAdapterImpl { 30 public: OpAdapterImpl(const std::unordered_map<int,InputDesc> & input_map,const std::unordered_map<int,DynInputDesc> & dyn_input_map,const std::unordered_map<int,OutputDesc> & output_map,const std::unordered_map<int,DynOutputDesc> & dyn_output_map,const std::unordered_map<int,DynSubGraphDesc> & dyn_subgraph_map,const std::unordered_map<std::string,AttrDesc> & attr_map,const std::unordered_map<std::string,int> & enum_map,const std::unordered_map<unsigned int,AttrDesc> & input_attr_map,std::unordered_map<std::string,std::unordered_map<int,std::string>> * cus_input_map,std::unordered_map<std::string,std::unordered_map<int,std::string>> * cus_output_map,std::unordered_map<std::string,ValuePtr> * extra_attr,std::unordered_map<std::string,int> * name_counts,BaseOpAdapter * adpt)31 OpAdapterImpl(const std::unordered_map<int, InputDesc> &input_map, 32 const std::unordered_map<int, DynInputDesc> &dyn_input_map, 33 const std::unordered_map<int, OutputDesc> &output_map, 34 const std::unordered_map<int, DynOutputDesc> &dyn_output_map, 35 const std::unordered_map<int, DynSubGraphDesc> &dyn_subgraph_map, 36 const std::unordered_map<std::string, AttrDesc> &attr_map, 37 const std::unordered_map<std::string, int> &enum_map, 38 const std::unordered_map<unsigned int, AttrDesc> &input_attr_map, 39 std::unordered_map<std::string, std::unordered_map<int, std::string>> *cus_input_map, 40 std::unordered_map<std::string, std::unordered_map<int, std::string>> *cus_output_map, 41 std::unordered_map<std::string, ValuePtr> *extra_attr, 42 std::unordered_map<std::string, int> *name_counts, BaseOpAdapter *adpt) 43 : input_map_(input_map), 44 dyn_input_map_(dyn_input_map), 45 output_map_(output_map), 46 dyn_output_map_(dyn_output_map), 47 dyn_subgraph_map_(dyn_subgraph_map), 48 attr_map_(attr_map), 49 enum_map_(enum_map), 50 input_attr_map_(input_attr_map), 51 cus_input_map_(cus_input_map), 52 cus_output_map_(cus_output_map), 53 extra_attr_(extra_attr), 54 name_counts_(name_counts), 55 adpt_(adpt) { 56 MS_EXCEPTION_IF_NULL(cus_input_map_); 57 MS_EXCEPTION_IF_NULL(cus_output_map_); 58 MS_EXCEPTION_IF_NULL(extra_attr_); 59 MS_EXCEPTION_IF_NULL(name_counts_); 60 MS_EXCEPTION_IF_NULL(adpt_); 61 } ~OpAdapterImpl()62 ~OpAdapterImpl() {} 63 bool IsCustomOp(const OperatorPtr &op); 64 Status GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim); 65 Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim); 66 OperatorPtr GenerateCustomOp(const AnfNodePtr anf); 67 Status SetOpSubgraphFunc(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<DfGraph>> &branches); 68 Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input); 69 Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input); 70 int setInput(const OperatorPtr &op, int index, const OperatorPtr &input); 71 Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle); 72 Status SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle); 73 int setInput(const OperatorPtr &op, int index, const OutHandler &handle); 74 int setInput(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<OutHandler>> &handler_vec); 75 OutHandler getOutput(const OperatorPtr &op, int index); 76 OutHandler getCustomOutput(const OperatorPtr &op, int index); 77 OutHandler getNormalOutput(const OperatorPtr &op, int index); 78 Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, 79 const std::string &format); 80 size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op); 81 std::shared_ptr<GeTensorDesc> CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type, 82 const std::string &format); 83 Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, 84 const std::string &format); 85 std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format); 86 void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node, const std::string format); 87 void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format); 88 void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node); 89 void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, 90 const AnfNodePtr &node); 91 int setAttr(const OperatorPtr &op, const std::string &attr_key, const ValuePtr &attr_value); 92 int SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim); 93 int SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim); 94 int setAttr(const OperatorPtr &op, const PrimitivePtr &prim); 95 int setAttr(const OperatorPtr &op, const AnfNodePtr &node); 96 97 private: 98 const std::unordered_map<int, InputDesc> &input_map_; 99 const std::unordered_map<int, DynInputDesc> &dyn_input_map_; 100 const std::unordered_map<int, OutputDesc> &output_map_; 101 const std::unordered_map<int, DynOutputDesc> &dyn_output_map_; 102 const std::unordered_map<int, DynSubGraphDesc> &dyn_subgraph_map_; 103 const std::unordered_map<std::string, AttrDesc> &attr_map_; 104 const std::unordered_map<std::string, int> &enum_map_; 105 const std::unordered_map<unsigned int, AttrDesc> &input_attr_map_; 106 std::unordered_map<std::string, std::unordered_map<int, std::string>> *const cus_input_map_; 107 std::unordered_map<std::string, std::unordered_map<int, std::string>> *const cus_output_map_; 108 std::unordered_map<std::string, ValuePtr> *const extra_attr_; 109 std::unordered_map<std::string, int> *const name_counts_; 110 BaseOpAdapter *const adpt_; 111 }; 112 113 template <typename T> 114 class OpAdapter : public BaseOpAdapter { 115 public: 116 using OpType = T; OpAdapter()117 OpAdapter() 118 : impl_(std::make_shared<OpAdapterImpl>(input_map_, dyn_input_map_, output_map_, dyn_output_map_, 119 dyn_subgraph_map_, attr_map_, enum_map_, input_attr_map_, &cus_input_map_, 120 &cus_output_map_, &extra_attr_, &name_counts_, this)) { 121 MS_EXCEPTION_IF_NULL(impl_); 122 } OpAdapter(const ExtraAttr & extra_attr)123 explicit OpAdapter(const ExtraAttr &extra_attr) 124 : extra_attr_(extra_attr), 125 impl_(std::make_shared<OpAdapterImpl>(input_map_, dyn_input_map_, output_map_, dyn_output_map_, 126 dyn_subgraph_map_, attr_map_, enum_map_, input_attr_map_, &cus_input_map_, 127 &cus_output_map_, &extra_attr_, &name_counts_, this)) { 128 MS_EXCEPTION_IF_NULL(impl_); 129 } ~OpAdapter()130 ~OpAdapter() override {} 131 IsCustomOp(const OperatorPtr & op)132 bool IsCustomOp(const OperatorPtr &op) { return impl_->IsCustomOp(op); } 133 GenerateCustomOpInputMap(const CusOperatorPtr & op,const PrimitivePtr & prim)134 Status GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) { 135 return impl_->GenerateCustomOpInputMap(op, prim); 136 } 137 GenerateCustomOpOutputMap(const CusOperatorPtr & op,const PrimitivePtr & prim)138 Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) { 139 return impl_->GenerateCustomOpOutputMap(op, prim); 140 } 141 142 // Convert ME UserCustom AnfNode to GE CustomOp. And set it's attrs. GenerateCustomOp(const AnfNodePtr anf)143 OperatorPtr GenerateCustomOp(const AnfNodePtr anf) { return impl_->GenerateCustomOp(anf); } 144 GenerateNormalOp(const AnfNodePtr & anf)145 OperatorPtr GenerateNormalOp(const AnfNodePtr &anf) { 146 OperatorPtr op = nullptr; 147 // There are duplicate names in ANF graph, do not assign ANF node name to GE 148 // GE will generate unique name automatically 149 if (anf != nullptr && anf->fullname_with_scope() != "") { 150 MS_LOG(DEBUG) << anf->fullname_with_scope(); 151 op = std::make_shared<OpType>(anf->fullname_with_scope()); 152 } else { 153 MS_LOG(DEBUG) << "no fullname_with_scope"; 154 op = std::make_shared<OpType>(); 155 } 156 157 // set dynamic output num if op use DYNAMIC_OUTPUT 158 if ((op != nullptr) && (!dyn_output_map_.empty()) && (anf != nullptr)) { 159 TypePtr type = anf->Type(); 160 if (type == nullptr) { 161 MS_LOG(EXCEPTION) << "Dynamic output node:" << op->GetName() << "'s Type is a nullptr!"; 162 } 163 size_t num = type->isa<Tuple>() ? (type->cast<std::shared_ptr<Tuple>>()->size()) : 1; 164 MS_LOG(INFO) << "create_dyn_output for node:" << anf->ToString() << ", type:" << type->ToString() 165 << ", num:" << num; 166 dyn_output_map_.begin()->second.create_dyn_output(op, static_cast<unsigned int>(num)); 167 } 168 return op; 169 } 170 generate(const AnfNodePtr & anf)171 OperatorPtr generate(const AnfNodePtr &anf) override { 172 OperatorPtr op = nullptr; 173 if (IsCustomCNode(anf)) { 174 op = GenerateCustomOp(anf); 175 } else { 176 op = GenerateNormalOp(anf); 177 } 178 if (op == nullptr) { 179 MS_LOG(EXCEPTION) << "Can not generate op for " << anf->fullname_with_scope(); 180 } 181 return op; 182 } 183 generate(const std::string & op_name)184 OperatorPtr generate(const std::string &op_name) override { return std::make_shared<OpType>(op_name); } 185 getInputMap()186 const std::unordered_map<int, InputDesc> &getInputMap() override { return input_map_; } getInputAttrMap()187 const std::unordered_map<unsigned int, AttrDesc> &getInputAttrMap() override { return input_attr_map_; } getDynInputMap()188 const std::unordered_map<int, DynInputDesc> &getDynInputMap() override { return dyn_input_map_; } getOutputMap()189 const std::unordered_map<int, OutputDesc> &getOutputMap() override { return output_map_; } getDynSubgraphMap()190 const std::unordered_map<int, DynSubGraphDesc> &getDynSubgraphMap() override { return dyn_subgraph_map_; } 191 SetOpSubgraphFunc(const OperatorPtr & op,int index,std::shared_ptr<std::vector<DfGraph>> branches)192 Status SetOpSubgraphFunc(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) { 193 return impl_->SetOpSubgraphFunc(op, index, branches); 194 } 195 setSubgraph(const OperatorPtr & op,int index,std::shared_ptr<std::vector<DfGraph>> branches)196 int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr<std::vector<DfGraph>> branches) override { 197 return static_cast<int>(SetOpSubgraphFunc(op, index, branches)); 198 } 199 SetCustomOpInput(const CusOperatorPtr & op,int index,const OperatorPtr & input)200 Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) { 201 return impl_->SetCustomOpInput(op, index, input); 202 } 203 SetNormalOpInput(const OperatorPtr & op,int index,const OperatorPtr & input)204 Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input) { 205 return impl_->SetNormalOpInput(op, index, input); 206 } 207 setInput(const OperatorPtr & op,int index,const OperatorPtr & input)208 int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) override { 209 return impl_->setInput(op, index, input); 210 } 211 SetCustomOpInput(const CusOperatorPtr & op,int index,const OutHandler & handle)212 Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle) { 213 return impl_->SetCustomOpInput(op, index, handle); 214 } 215 SetNormalOpInput(const OperatorPtr & op,int index,const OutHandler & handle)216 Status SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle) { 217 return impl_->SetNormalOpInput(op, index, handle); 218 } 219 setInput(const OperatorPtr & op,int index,const OutHandler & handle)220 int setInput(const OperatorPtr &op, int index, const OutHandler &handle) override { 221 return impl_->setInput(op, index, handle); 222 } 223 setInput(const OperatorPtr & op,int index,const std::shared_ptr<std::vector<OutHandler>> & handler_vec)224 int setInput(const OperatorPtr &op, int index, const std::shared_ptr<std::vector<OutHandler>> &handler_vec) override { 225 return impl_->setInput(op, index, handler_vec); 226 } 227 getOutput(const OperatorPtr & op,int index)228 OutHandler getOutput(const OperatorPtr &op, int index) override { return impl_->getOutput(op, index); } 229 getCustomOutput(const OperatorPtr & op,int index)230 OutHandler getCustomOutput(const OperatorPtr &op, int index) { return impl_->getCustomOutput(op, index); } 231 getNormalOutput(const OperatorPtr & op,int index)232 OutHandler getNormalOutput(const OperatorPtr &op, int index) { return impl_->getNormalOutput(op, index); } 233 UpdateSingleOutputDesc(const OperatorPtr & op,const abstract::BaseShapePtr & shp,const TypePtr & type,const std::string & format)234 Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, 235 const std::string &format) { 236 return impl_->UpdateSingleOutputDesc(op, shp, type, format); 237 } 238 GetCustomOpOutputSize(const CusOperatorPtr & cus_op)239 size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { return impl_->GetCustomOpOutputSize(cus_op); } 240 CreateOutputDesc(const abstract::ShapePtr & shape_ptr,const TypePtr & type,const std::string & format)241 std::shared_ptr<GeTensorDesc> CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type, 242 const std::string &format) { 243 return impl_->CreateOutputDesc(shape_ptr, type, format); 244 } 245 UpdateMultiOutputDesc(const OperatorPtr & op,const abstract::BaseShapePtr & shp,const TypePtr & type,const std::string & format)246 Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, 247 const std::string &format) { 248 return impl_->UpdateMultiOutputDesc(op, shp, type, format); 249 } 250 CreateNodeDesc(const AnfNodePtr & node,const std::string & format)251 std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format) { 252 return impl_->CreateNodeDesc(node, format); 253 } 254 UpdateNormalOpInputDesc(const OperatorPtr & op,const AnfNodePtr node,const std::string format)255 void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node, const std::string format) { 256 return impl_->UpdateNormalOpInputDesc(op, node, format); 257 } 258 UpdateCustomOpInputDesc(const CusOperatorPtr & op,const AnfNodePtr & node,const std::string format)259 void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format) { 260 return impl_->UpdateCustomOpInputDesc(op, node, format); 261 } 262 updateInputDesc(const OperatorPtr & op,const AnfNodePtr & node)263 void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { impl_->updateInputDesc(op, node); } 264 updateOutputDesc(const OperatorPtr & op,const abstract::BaseShapePtr & shp,const TypePtr & type,const AnfNodePtr & node)265 void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, 266 const AnfNodePtr &node) override { 267 impl_->updateOutputDesc(op, shp, type, node); 268 } 269 setAttr(const OperatorPtr & op,const std::string & attrKey,const ValuePtr & attrValue)270 int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) override { 271 return impl_->setAttr(op, attrKey, attrValue); 272 } 273 SetCustomOpAttr(const CusOperatorPtr & op,const PrimitivePtr & prim)274 int SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim) { return impl_->SetCustomOpAttr(op, prim); } 275 SetNormalOpAttr(const OperatorPtr & op,const PrimitivePtr & prim)276 int SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim) { return impl_->SetNormalOpAttr(op, prim); } 277 setAttr(const OperatorPtr & op,const PrimitivePtr & prim)278 int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) override { return impl_->setAttr(op, prim); } 279 setAttr(const OperatorPtr & op,const AnfNodePtr & node)280 int setAttr(const OperatorPtr &op, const AnfNodePtr &node) override { return impl_->setAttr(op, node); } 281 GetExtraAttr()282 std::unordered_map<std::string, ValuePtr> GetExtraAttr() override { return extra_attr_; } 283 284 private: 285 template <typename S> ConvertAny(const ValuePtr & value,const AnyTraits<S> &)286 static S ConvertAny(const ValuePtr &value, const AnyTraits<S> &) { 287 return GetValue<S>(value); 288 } 289 290 // specialization for reverse bool ConvertAny(const ValuePtr & value,const AnyTraits<bool> &,bool reverse)291 static bool ConvertAny(const ValuePtr &value, const AnyTraits<bool> &, bool reverse) { 292 return reverse != GetValue<bool>(value); 293 } 294 295 template <typename P, typename Q> ConvertAny(const ValuePtr & value,const AnyTraits<P> & traits_from,const AnyTraits<Q> & traits_to)296 static Q ConvertAny(const ValuePtr &value, const AnyTraits<P> &traits_from, const AnyTraits<Q> &traits_to) { 297 return ConvertAnyUtil(value, traits_from, traits_to); 298 } 299 300 // specialization for tensor ConvertAny(const ValuePtr & value,const AnyTraits<mindspore::tensor::Tensor> & traits)301 static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits<mindspore::tensor::Tensor> &traits) { 302 // To-DO the format may read from ME tensor 303 return ConvertAnyUtil(value, traits); 304 } 305 306 // specialization for int ConvertAny(const ValuePtr & value,const AnyTraits<int64_t>)307 static int64_t ConvertAny(const ValuePtr &value, const AnyTraits<int64_t>) { 308 return static_cast<int64_t>(GetValue<int64_t>(value)); 309 } 310 311 // specialization for int or tuple broadcast to Vector ConvertAny(const ValuePtr & value,const std::string & name,const AnyTraits<std::vector<int64_t>> anyTraitsInt)312 static std::vector<int64_t> ConvertAny(const ValuePtr &value, const std::string &name, 313 const AnyTraits<std::vector<int64_t>> anyTraitsInt) { 314 return ConvertAnyUtil(value, name, anyTraitsInt); 315 } 316 ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<std::vector<int64_t>>>)317 static std::vector<std::vector<int64_t>> ConvertAny(const ValuePtr &value, 318 const AnyTraits<std::vector<std::vector<int64_t>>>) { 319 MS_EXCEPTION_IF_NULL(value); 320 MS_LOG(INFO) << "Value: " << value->type_name(); 321 std::vector<std::vector<int64_t>> list; 322 if (!value->isa<ValueTuple>()) { 323 MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got " << value->type_name(); 324 } 325 auto vec = value->cast<ValueTuplePtr>(); 326 MS_EXCEPTION_IF_NULL(vec); 327 for (auto &it : vec->value()) { 328 MS_EXCEPTION_IF_NULL(it); 329 if (!it->isa<ValueTuple>()) { 330 MS_LOG(EXCEPTION) << "It should be ValueTuple, but got " << it->type_name(); 331 } 332 auto sub_vector = it->cast<ValueTuplePtr>(); 333 std::vector<int64_t> sublist; 334 for (auto &item : sub_vector->value()) { 335 sublist.push_back(static_cast<int64_t>(GetValue<int64_t>(item))); 336 } 337 list.push_back(sublist); 338 } 339 return list; 340 } 341 ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<std::vector<int64_t>>>,const AnyTraits<std::vector<int64_t>>)342 static std::vector<int64_t> ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<std::vector<int64_t>>>, 343 const AnyTraits<std::vector<int64_t>>) { 344 MS_EXCEPTION_IF_NULL(value); 345 MS_LOG(DEBUG) << "Value: " << value->type_name(); 346 if (!value->isa<ValueList>()) { 347 MS_LOG(EXCEPTION) << "Value should be ValueList, but got " << value->type_name(); 348 } 349 auto vec = value->cast<ValueListPtr>(); 350 std::vector<int64_t> list; 351 for (auto &it : vec->value()) { 352 MS_EXCEPTION_IF_NULL(it); 353 if (!it->isa<ValueList>()) { 354 MS_LOG(EXCEPTION) << "It should be ValueList, but got " << it->type_name(); 355 } 356 auto sub_vector = it->cast<ValueListPtr>(); 357 for (auto &item : sub_vector->value()) { 358 list.push_back(static_cast<int64_t>(GetValue<int64_t>(item))); 359 } 360 } 361 return list; 362 } 363 ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<int64_t>>,const AnyTraits<std::vector<int64_t>>)364 static std::vector<int64_t> ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<int64_t>>, 365 const AnyTraits<std::vector<int64_t>>) { 366 MS_EXCEPTION_IF_NULL(value); 367 MS_LOG(INFO) << "Value: " << value->type_name(); 368 std::vector<int64_t> list; 369 if (value->isa<ValueSequeue>()) { 370 auto vec = value->cast<ValueSequeuePtr>(); 371 MS_EXCEPTION_IF_NULL(vec); 372 for (auto &it : vec->value()) { 373 list.push_back(static_cast<int64_t>(GetValue<int64_t>(it))); 374 } 375 return list; 376 } 377 if (value->isa<Scalar>()) { 378 list.push_back(static_cast<int64_t>(GetValue<int64_t>(value))); 379 return list; 380 } 381 MS_LOG(EXCEPTION) << "Value should be ValueTuple or Scalar, but got " << value->type_name(); 382 } 383 ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<int64_t>> anyTraitsVec,const AnyTraits<std::string> anyTraitsStr)384 static std::string ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<int64_t>> anyTraitsVec, 385 const AnyTraits<std::string> anyTraitsStr) { 386 return ConvertAnyUtil(value, anyTraitsVec, anyTraitsStr); 387 } 388 ConvertAny(const ValuePtr & value,const AnyTraits<std::vector<float>> anyTraitsVec,const AnyTraits<float> anyTraitsFlo)389 static std::vector<float> ConvertAny(const ValuePtr &value, const AnyTraits<std::vector<float>> anyTraitsVec, 390 const AnyTraits<float> anyTraitsFlo) { 391 return ConvertAnyUtil(value, anyTraitsVec, anyTraitsFlo); 392 } 393 ConvertAny(const ValuePtr & value,const std::string & format,const AnyTraits<std::vector<int64_t>> anyTraitsVec,const AnyTraits<int64_t> anyTraitsInt)394 static std::vector<int64_t> ConvertAny(const ValuePtr &value, const std::string &format, 395 const AnyTraits<std::vector<int64_t>> anyTraitsVec, 396 const AnyTraits<int64_t> anyTraitsInt) { 397 return ConvertAnyUtil(value, format, anyTraitsVec, anyTraitsInt); 398 } 399 400 // convert value list for value tuple to vector 401 template <typename P, typename Q> ConvertAny(const ValuePtr & value,const AnyTraits<P> & anyTraitsP,const AnyTraits<std::vector<Q>> anyTraitsQ)402 static std::vector<Q> ConvertAny(const ValuePtr &value, const AnyTraits<P> &anyTraitsP, 403 const AnyTraits<std::vector<Q>> anyTraitsQ) { 404 return ConvertAnyUtil(value, anyTraitsP, anyTraitsQ); 405 } 406 ConvertAny(const ValuePtr & value,const AnyTraits<GeEnum>)407 static int64_t ConvertAny(const ValuePtr &value, const AnyTraits<GeEnum>) { 408 auto name = GetValue<std::string>(value); 409 auto it = enum_map_.find(name); 410 int v = 0; 411 if (it != enum_map_.end()) { 412 v = it->second; 413 } 414 return v; 415 } 416 ConvertAny(const ValuePtr & value,const AnyTraits<GEType> anyTraitsGE)417 static GeDataType ConvertAny(const ValuePtr &value, const AnyTraits<GEType> anyTraitsGE) { 418 return ConvertAnyUtil(value, anyTraitsGE); 419 } 420 421 // convert any value to tensor ConvertAny(const ValuePtr & value,const AnyTraits<AnyValue> anyTraitsValue)422 static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits<AnyValue> anyTraitsValue) { 423 return ConvertAnyUtil(value, anyTraitsValue); 424 } 425 426 static const std::unordered_map<int, InputDesc> input_map_; 427 static const std::unordered_map<int, DynInputDesc> dyn_input_map_; 428 static const std::unordered_map<int, OutputDesc> output_map_; 429 static const std::unordered_map<int, DynOutputDesc> dyn_output_map_; 430 static const std::unordered_map<int, DynSubGraphDesc> dyn_subgraph_map_; 431 static const std::unordered_map<std::string, AttrDesc> attr_map_; 432 static const std::unordered_map<std::string, int> enum_map_; 433 // convert input from anf graph to Attr in Operators 434 static const std::unordered_map<unsigned int, AttrDesc> input_attr_map_; 435 static std::unordered_map<std::string, std::unordered_map<int, std::string>> cus_input_map_; 436 static std::unordered_map<std::string, std::unordered_map<int, std::string>> cus_output_map_; 437 std::unordered_map<std::string, ValuePtr> extra_attr_; 438 std::unordered_map<std::string, int> name_counts_; 439 const std::shared_ptr<OpAdapterImpl> impl_; 440 }; 441 442 template <typename T> 443 const std::unordered_map<int, InputDesc> OpAdapter<T>::input_map_; 444 template <typename T> 445 const std::unordered_map<int, DynInputDesc> OpAdapter<T>::dyn_input_map_; 446 template <typename T> 447 const std::unordered_map<int, OutputDesc> OpAdapter<T>::output_map_; 448 template <typename T> 449 const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_; 450 template <typename T> 451 const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_; 452 template <typename T> 453 const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_; 454 template <typename T> 455 const std::unordered_map<std::string, int> OpAdapter<T>::enum_map_; 456 template <typename T> 457 const std::unordered_map<unsigned int, AttrDesc> OpAdapter<T>::input_attr_map_; 458 template <typename T> 459 std::unordered_map<std::string, std::unordered_map<int, std::string>> OpAdapter<T>::cus_input_map_; 460 template <typename T> 461 std::unordered_map<std::string, std::unordered_map<int, std::string>> OpAdapter<T>::cus_output_map_; 462 463 // specialization for method 464 } // namespace transform 465 } // namespace mindspore 466 467 #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_H_ 468