/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "transform/graph_ir/op_adapter.h"

#include <algorithm>
#include <map>
#include <string>
#include <unordered_set>
#include "utils/check_convert_utils.h"
#include "op_proto/inc/split_combination_ops.h"
#include "graph/operator_factory.h"
#include "include/common/utils/convert_utils.h"
#include "utils/anf_utils.h"
#include "include/common/utils/anfalgo.h"

namespace mindspore {
namespace transform {
ge::graphStatus CustomAkgOpInferFunc(ge::Operator &op);

ge::graphStatus CustomTbeAicpuOpInferFunc(ge::Operator &op);

enum class CustomOpType { kUnKnown, kAkg, kTbe, kAiCpu };

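// Classify a Custom primitive into one of the known kernel backends (AKG graph kernel, TBE or
// AICPU) based on its 'type' and 'func_type' attributes; anything else is reported as kUnKnown.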
CustomOpType GetCustomOpTypeDetail(const PrimitivePtr &prim) {
  if (prim == nullptr) {
    return CustomOpType::kUnKnown;
  }
  auto type = prim->GetAttr("type");
  if (type != nullptr && GetValue<std::string>(type) == "GraphKernel") {
    return CustomOpType::kAkg;
  }
  auto func_type = prim->GetAttr("func_type");
  if (func_type != nullptr) {
    auto func_type_value = GetValue<std::string>(func_type);
    if (func_type_value == "tbe") {
      return CustomOpType::kTbe;
    } else if (func_type_value == "aicpu") {
      return CustomOpType::kAiCpu;
    }
  }
  return CustomOpType::kUnKnown;
}

ValuePtr GetCustomOpInputNames(const PrimitivePtr &prim) {
  MS_EXCEPTION_IF_NULL(prim);
  // MS Custom op 'input_names' includes attr names, which is not expected
  auto value = prim->GetAttr("pure_input_names");
  if (value == nullptr) {
    value = prim->GetAttr("input_names");
  }
  return value;
}

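// Collect the attribute names that must be registered on a Custom op. A required attr without a
// value is reported as an error, missing optional attrs are skipped silently, and the AICPU-only
// helper attr 'cust_aicpu' is excluded from the result.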
std::vector<std::string> GetCustomOpKernelAttrs(const PrimitivePtr &prim) {
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<std::string> res;
  auto op_type = GetCustomOpTypeDetail(prim);
  auto attr_names = prim->GetAttr("attr_names");
  if (attr_names != nullptr) {
    auto names = GetValue<std::vector<std::string>>(attr_names);
    std::vector<std::string> optional_attrs;
    auto attr_ptr = prim->GetAttr("missing_optional_attrs");
    if (attr_ptr != nullptr) {
      optional_attrs = GetValue<std::vector<std::string>>(attr_ptr);
    }
    for (const auto &name : names) {
      if (!prim->HasAttr(name)) {
        // optional attr can have no value, but required attr must have value
        if (std::find(optional_attrs.begin(), optional_attrs.end(), name) == std::end(optional_attrs)) {
          MS_LOG(ERROR) << "Custom op attr '" << name << "' value not set";
        }
      } else {
        if (op_type == CustomOpType::kAiCpu && name == "cust_aicpu") {
          continue;
        }
        res.push_back(name);
      }
    }
  }
  return res;
}

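// Register a Custom op type with the GE OperatorFactory if it is not registered yet: a creator
// lambda that rebuilds the operator's inputs, outputs and required attrs by name, plus the
// matching InferShape function (AKG vs TBE/AICPU).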
void RegisterCustomOp(const PrimitivePtr &prim, const std::string &op_type, const std::vector<std::string> &attr_names,
                      bool is_akg) {
  if (ge::OperatorFactory::IsExistOp(op_type)) {
    return;
  }
  MS_EXCEPTION_IF_NULL(prim);
  auto input_names_v = GetCustomOpInputNames(prim);
  MS_EXCEPTION_IF_NULL(input_names_v);
  auto input_names = GetValue<std::vector<std::string>>(input_names_v);
  auto output_names_v = prim->GetAttr("output_names");
  MS_EXCEPTION_IF_NULL(output_names_v);
  auto output_names = GetValue<std::vector<std::string>>(output_names_v);
  // Register op create function, which describes how to create a custom op
  ::ge::OperatorCreatorRegister op_create_reg(
    op_type, [op_type, input_names, output_names, attr_names, is_akg](const std::string &name) {
      auto op = ge::CustomOperator(name, op_type);
      for (const auto &in_name : input_names) {
        op.CustomInputRegister(in_name);
      }
      for (const auto &out_name : output_names) {
        op.CustomOutputRegister(out_name);
      }
      for (const auto &attr_name : attr_names) {
        op.CustomRequiredAttrRegister(attr_name);
      }
      if (is_akg) {
        op.CustomInferFuncRegister(CustomAkgOpInferFunc);
      } else {
        op.CustomInferFuncRegister(CustomTbeAicpuOpInferFunc);
      }
      return op;
    });
  // Register op infer shape function
  if (is_akg) {
    ::ge::InferShapeFuncRegister infer(op_type, CustomAkgOpInferFunc);
  } else {
    ::ge::InferShapeFuncRegister infer(op_type, CustomTbeAicpuOpInferFunc);
  }
}

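// Build a mapping from the logical (MindSpore) input position to the real flattened input
// position when the node carries kAttrDynInputSizes, i.e. when several anf inputs are folded
// into one dynamic input; an empty vector means no remapping is needed.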
static std::vector<int64_t> GetRealInputIndices(const CNodePtr &cnode) {
  if (!common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, cnode)) {
    return std::vector<int64_t>{};
  }
  std::vector<int64_t> real_input_indices =
    common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrDynInputSizes);
  int64_t count = 0;
  // construct real input indices based on attribute kAttrDynInputSizes
  for (size_t i = 0; i < real_input_indices.size(); ++i) {
    int64_t num_folded_inputs = (real_input_indices[i] < 0 ? 1 : real_input_indices[i]);
    real_input_indices[i] = count;
    count += num_folded_inputs;
  }
  return real_input_indices;
}

static inline size_t GetRealAnfInputIndex(const std::vector<int64_t> &real_input_indices, size_t anf_input_index) {
  // NOTE: anf_input_index starts from 1; index 0 corresponds to the primitive value node
  size_t input_index = anf_input_index - 1;
  size_t real_index =
    input_index < real_input_indices.size() ? static_cast<size_t>(real_input_indices[input_index]) : input_index;
  // finally convert the real index back to an anf node input index
  return real_index + 1;
}

bool OpAdapterImpl::IsCustomOp(const OperatorPtr &op) const {
  MS_EXCEPTION_IF_NULL(op);
  auto it = cus_input_map_->find(op->GetOpType());
  if (it == cus_input_map_->end()) {
    return false;
  }
  return true;
}

Status OpAdapterImpl::GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) {
  MS_EXCEPTION_IF_NULL(op);
  MS_EXCEPTION_IF_NULL(prim);
  // Create the map of custom op from input index to input name.
  mindspore::HashMap<int, std::string> input_map;
  auto op_type = GetCustomOpType(prim);
  auto value = GetCustomOpInputNames(prim);
  if (value == nullptr) {
    (*cus_output_map_)[op_type] = std::map<int, std::string>{};
    return NOT_FOUND;
  }

  auto input_names = GetValue<const std::vector<std::string>>(value);
  for (size_t i = 0; i < input_names.size(); ++i) {
    // input_map indices begin from 1
    input_map[i + 1] = input_names[i];
    op->CustomInputRegister(input_names[i]);
  }

  if (cus_input_map_->find(op_type) == cus_input_map_->end()) {
    (*cus_input_map_)[op_type] = input_map;
  }
  return SUCCESS;
}

Status OpAdapterImpl::GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) {
  MS_EXCEPTION_IF_NULL(op);
  MS_EXCEPTION_IF_NULL(prim);
  // Create the map of custom op from output index to output name.
  std::map<int, std::string> output_map;
  auto op_type = GetCustomOpType(prim);
  auto value = prim->GetAttr("output_names");
  if (value == nullptr) {
    // generate an empty output_map for it
    (*cus_output_map_)[op_type] = output_map;
    return NOT_FOUND;
  }

  auto output_names = GetValue<const std::vector<std::string>>(value);
  for (size_t i = 0; i < output_names.size(); ++i) {
    // output_map indices begin from 0
    output_map[i] = output_names[i];
    op->CustomOutputRegister(output_names[i]);
  }

  if (cus_output_map_->find(op_type) == cus_output_map_->end()) {
    (*cus_output_map_)[op_type] = output_map;
  }
  return SUCCESS;
}

std::string OpAdapterImpl::GetCustomOpType(const PrimitivePtr &prim) const {
  MS_EXCEPTION_IF_NULL(prim);
  auto detail_type = GetCustomOpTypeDetail(prim);
  if (detail_type == CustomOpType::kTbe) {
    auto func_name = prim->GetAttr("func_name");
    if (func_name == nullptr) {
      MS_LOG(ERROR) << "Custom tbe op has no 'func_name' attr.";
      return "";
    }
    return GetValue<std::string>(func_name);
  }
  auto value = prim->GetAttr("reg_op_name");
  if (value == nullptr) {
    MS_LOG(ERROR) << "Custom op has no reg_op_name attr.";
    return "";
  }
  auto op_type = GetValue<std::string>(value);
  return op_type;
}

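// Create a ge::CustomOperator for a Custom CNode: derive the op type from the primitive, build
// the input/output name maps, register the required attrs and InferShape function for the
// detected backend, and register the op type with the OperatorFactory.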
OperatorPtr OpAdapterImpl::GenerateCustomOp(const AnfNodePtr anf) {
  MS_EXCEPTION_IF_NULL(anf);
  auto node = anf->cast<CNodePtr>();
  if (node == nullptr) {
    return nullptr;
  }

  if (node->inputs().empty()) {
    MS_LOG(EXCEPTION) << "length of node inputs is empty";
  }

  auto prim = GetValueNode<PrimitivePtr>(node->inputs()[0]);
  MS_EXCEPTION_IF_NULL(prim);
  auto op_type = GetCustomOpType(prim);
  auto op = std::make_shared<::ge::CustomOperator>(node->fullname_with_scope() + op_type, op_type);
  MS_EXCEPTION_IF_NULL(op);
  if (GenerateCustomOpInputMap(op, prim) != SUCCESS) {
    MS_LOG(WARNING) << "Custom op node has no input_names, op[" << prim->name() << "].";
  }

  if (GenerateCustomOpOutputMap(op, prim) != SUCCESS) {
    MS_LOG(WARNING) << "Custom op node has no output_names, op[" << prim->name() << "].";
  }

  auto detail_type = GetCustomOpTypeDetail(prim);
  if (detail_type == CustomOpType::kAkg) {
    std::vector<std::string> attr_names{"info_path"};
    op->CustomRequiredAttrRegister(attr_names[0]);
    op->CustomInferFuncRegister(CustomAkgOpInferFunc);
    RegisterCustomOp(prim, op_type, attr_names, true);
  } else if (detail_type == CustomOpType::kTbe || detail_type == CustomOpType::kAiCpu) {
    auto attr_names = GetCustomOpKernelAttrs(prim);
    for (const auto &attr_name : attr_names) {
      op->CustomRequiredAttrRegister(attr_name);
    }
    op->CustomInferFuncRegister(CustomTbeAicpuOpInferFunc);
    RegisterCustomOp(prim, op_type, attr_names, false);
  } else {
    MS_LOG(INFO) << "For custom operators, users need to define and implement the Infershape function by themselves.";
  }

  return op;
}

Status OpAdapterImpl::SetOpSubgraphFunc(const OperatorPtr &op, int index,
                                        const std::shared_ptr<std::vector<DfGraph>> &branches) {
  MS_EXCEPTION_IF_NULL(op);
  auto it = dyn_subgraph_map_.find(index);
  if (it != dyn_subgraph_map_.end()) {
    auto size = branches->size();
    it->second.create_dyn_subgraph(op, static_cast<unsigned int>(size));
    for (size_t i = 0; i < size; i++) {
      it->second.set_subgraph(op, static_cast<unsigned int>(i), std::make_shared<DfGraph>((*branches)[i]));
    }
    return SUCCESS;
  }
  return NOT_FOUND;
}

Status OpAdapterImpl::SetOpSubgraphFunc(const OperatorPtr &op, const std::shared_ptr<std::vector<DfGraph>> &subgraphs) {
  MS_EXCEPTION_IF_NULL(op);
  if (subgraph_map_.size() != subgraphs->size()) {
    return INVALID_ARGUMENT;
  }
  for (size_t i = 0; i < subgraphs->size(); i++) {
    subgraph_map_.at(i).set_subgraph(op, std::make_shared<DfGraph>((*subgraphs)[i]));
  }
  return SUCCESS;
}

Status OpAdapterImpl::SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) const {
  MS_EXCEPTION_IF_NULL(op);
  MS_EXCEPTION_IF_NULL(input);
  auto it = cus_input_map_->find(op->GetOpType());
  if (it == cus_input_map_->end()) {
    return NOT_FOUND;
  }
  mindspore::HashMap<int, std::string> &input_map = it->second;

  if ((input_map.find(index) != input_map.end())) {
    MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << input_map[index];
    (void)op->SetInput(input_map[index], *input);
    return SUCCESS;
  }
  return NOT_FOUND;
}

Status OpAdapterImpl::SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input) {
  MS_EXCEPTION_IF_NULL(op);
  auto it = input_map_.find(index);
  if (input != nullptr && it != input_map_.end()) {
    MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << it->second.name;
    it->second.set_op(op, input);
    return SUCCESS;
  }
  return NOT_FOUND;
}

int OpAdapterImpl::setInput(const OperatorPtr &op, int index, const OperatorPtr &input) {
  if (IsCustomOp(op)) {
    auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
    return static_cast<int>(SetCustomOpInput(cus_op, index, input));
  } else {
    return static_cast<int>(SetNormalOpInput(op, index, input));
  }
}

Status OpAdapterImpl::SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle) const {
  MS_EXCEPTION_IF_NULL(op);
  auto it = cus_input_map_->find(op->GetOpType());
  if (it == cus_input_map_->end()) {
    return NOT_FOUND;
  }

  mindspore::HashMap<int, std::string> &input_map = it->second;
  if ((handle.op != nullptr) && (input_map.find(index) != input_map.end())) {
    if (handle.out.empty()) {
      MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << input_map[index];
      (void)op->SetInput(input_map[index], *(handle.op));
    } else {
      MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << ":" << handle.out << " to " << op->GetName() << ":"
                    << input_map[index];
      (void)op->SetInput(input_map[index], *(handle.op), handle.out);
    }
    return SUCCESS;
  }
  return NOT_FOUND;
}

Status OpAdapterImpl::SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle) {
  MS_EXCEPTION_IF_NULL(op);
  auto it = input_map_.find(index);
  if ((handle.op != nullptr) && (it != input_map_.end())) {
    if (handle.out.empty()) {
      MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << it->second.name;
      it->second.set_op(op, handle.op);
    } else {
      MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << ":" << handle.out << " to " << op->GetName() << ":"
                    << it->second.name;
      it->second.set_handle(op, handle);
    }
    return SUCCESS;
  }
  return NOT_FOUND;
}

int OpAdapterImpl::setInput(const OperatorPtr &op, int index, const OutHandler &handle) {
  if (IsCustomOp(op)) {
    auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
    return static_cast<int>(SetCustomOpInput(cus_op, index, handle));
  } else {
    return static_cast<int>(SetNormalOpInput(op, index, handle));
  }
}

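// Dynamic-input overload: expand the dynamic input at `index` to the size of `handler_vec`
// (either by index or by name) and link every handler to its slot. Custom ops do not support
// dynamic inputs and are rejected here.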
int OpAdapterImpl::setInput(const OperatorPtr &op, int index,
                            const std::shared_ptr<std::vector<OutHandler>> &handler_vec, bool use_create_byindex_func,
                            size_t dyn_index) {
  MS_EXCEPTION_IF_NULL(handler_vec);
  if (IsCustomOp(op)) {
    MS_LOG(ERROR) << "Custom op does not support dynamic input";
    return static_cast<int>(FAILED);
  }
  MS_EXCEPTION_IF_NULL(op);
  auto it = dyn_input_map_.find(index);
  if (it != dyn_input_map_.end()) {
    if (use_create_byindex_func) {
      it->second.create_dyn_input_by_index(op, static_cast<unsigned int>(handler_vec->size()), dyn_index);
    } else {
      it->second.create_dyn_input(op, static_cast<unsigned int>(handler_vec->size()));
    }
    for (unsigned int i = 0; i < handler_vec->size(); ++i) {
      OutHandler h = (*handler_vec)[i];
      MS_EXCEPTION_IF_NULL(h.op);
      if (h.out.empty()) {
        MS_LOG(DEBUG) << "Link op " << h.op->GetName() << " to " << op->GetName() << ":" << it->second.name;
        it->second.set_op(op, (i), h.op);
      } else {
        MS_LOG(DEBUG) << "Link op " << h.op->GetName() << ":" << h.out << " to " << op->GetName() << ":"
                      << it->second.name;
        it->second.set_handle(op, i, h);
      }
    }
    return 0;
  }
  return static_cast<int>(NOT_FOUND);
}

OutHandler OpAdapterImpl::getOutput(const OperatorPtr &op, int index) {
  MS_EXCEPTION_IF_NULL(op);
  if (IsCustomOp(op)) {
    return getCustomOutput(op, index);
  }
  return getNormalOutput(op, index);
}

std::vector<OutHandler> OpAdapterImpl::getOutputs(const OperatorPtr &op) const {
  if (IsCustomOp(op)) {
    return getCustomOutputs(op);
  }
  return getNormalOutputs(op);
}

OutHandler OpAdapterImpl::getCustomOutput(const OperatorPtr &op, int index) const {
  MS_EXCEPTION_IF_NULL(op);
  auto it = cus_output_map_->find(op->GetOpType());
  if (it == cus_output_map_->end()) {
    MS_LOG(ERROR) << "OpAdapter(" << op->GetName() << ")'s OUTPUT is not supported!";
    return OutHandler();
  }

  std::map<int, std::string> &output_map = it->second;

  if ((output_map.find(index) != output_map.end())) {
    return OutHandler(op, output_map[index]);
  }
  MS_LOG(ERROR) << "OpAdapter(" << op->GetName() << ") has no OUTPUT index(" << index << ")!";
  return OutHandler();
}

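// Resolve output `index` of a normal (non-custom) op to a named OutHandler. An adapter may
// declare either static outputs (output_map_) or one dynamic output (dyn_output_map_), but not
// both; dynamic output names are formed by appending the index to the dynamic output's name.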
OutHandler OpAdapterImpl::getNormalOutput(const OperatorPtr &op, int index) {
  MS_EXCEPTION_IF_NULL(op);
  if (!dyn_output_map_.empty() && !output_map_.empty()) {
    MS_LOG(ERROR) << "OpAdapter(" << op->GetName() << ") has both OUTPUT and DYN_OUTPUT, which is not supported!";
    return OutHandler();
  }
  auto it = output_map_.find(index);
  if (it != output_map_.end()) {
    return OutHandler(op, it->second.name);
  } else if (!dyn_output_map_.empty()) {
    return OutHandler(op, dyn_output_map_.begin()->second.name + std::to_string(index));
  } else {
    MS_LOG(ERROR) << "OpAdapter(" << op->GetName() << ") has no OUTPUT and DYN_OUTPUT index(" << index << ")!";
    return OutHandler();
  }
}

std::vector<OutHandler> OpAdapterImpl::getNormalOutputs(const OperatorPtr &op) const {
  MS_EXCEPTION_IF_NULL(op);
  if (!dyn_output_map_.empty() && !output_map_.empty()) {
    MS_LOG(ERROR) << "OpAdapter(" << op->GetName() << ") has both OUTPUT and DYN_OUTPUT, which is not supported!";
    return std::vector<OutHandler>{};
  }
  std::vector<OutHandler> handles;
  std::transform(output_map_.begin(), output_map_.end(), std::back_inserter(handles),
                 [&op](const auto &item) { return OutHandler(op, item.second.name); });
  if (!dyn_output_map_.empty()) {
    auto dyn_output_name = dyn_output_map_.begin()->second.name;
    auto dyn_output_size = op->GetDynamicOutputNum(dyn_output_name);
    for (int i = 0; i < dyn_output_size; i++) {
      handles.emplace_back(OutHandler(op, dyn_output_name + std::to_string(i)));
    }
  }
  return handles;
}

std::vector<OutHandler> OpAdapterImpl::getCustomOutputs(const OperatorPtr &op) const {
  MS_EXCEPTION_IF_NULL(op);
  std::vector<OutHandler> handles;
  auto it = cus_output_map_->find(op->GetOpType());
  if (it == cus_output_map_->end()) {
    MS_LOG(ERROR) << "OpAdapter(" << op->GetName() << ")'s OUTPUT is not supported!";
    return handles;
  }
  std::transform(it->second.begin(), it->second.end(), std::back_inserter(handles),
                 [&op](const auto &item) { return OutHandler(op, item.second); });
  return handles;
}

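// Update the tensor descriptor of an op with exactly one output. For custom ops the descriptor
// is written through the custom output map and also mirrored into the 'output_shapes',
// 'output_formats' and 'output_types' attrs; for normal ops it goes to the static or dynamic
// output map.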
Status OpAdapterImpl::UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp,
                                             const TypePtr &type, const std::string &format) {
  MS_EXCEPTION_IF_NULL(type);

  auto desc = CreateOutputDesc(dyn_cast<abstract::Shape>(shp), type, format);
  if (desc == nullptr) {
    MS_LOG(ERROR) << "Update output descriptor failed!";
    return FAILED;
  }

  if (IsCustomOp(op)) {
    if (cus_output_map_->find(op->GetOpType()) == cus_output_map_->end() ||
        ((*cus_output_map_)[op->GetOpType()].empty())) {
      MS_LOG(ERROR) << "This op does not create custom output map";
      return FAILED;
    }
    auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
    MS_EXCEPTION_IF_NULL(cus_op);
    std::map<int, std::string> output_map = (*cus_output_map_)[op->GetOpType()];
    (void)cus_op->UpdateOutputDesc(output_map[0], *desc);
    std::vector<std::vector<int64_t>> out_shapes{desc->GetShape().GetDims()};
    std::vector<int32_t> out_formats{static_cast<int32_t>(desc->GetFormat())};
    std::vector<int32_t> out_types{static_cast<int32_t>(desc->GetDataType())};
    (void)cus_op->SetAttr("output_shapes", out_shapes);
    (void)cus_op->SetAttr("output_formats", out_formats);
    (void)cus_op->SetAttr("output_types", out_types);
  } else {
    if (!output_map_.empty()) {
      output_map_.begin()->second.update_out_desc(op, *desc);
    } else if (!dyn_output_map_.empty()) {
      dyn_output_map_.begin()->second.update_dyn_output_desc(op, 0, *desc);
    } else {
      MS_LOG(DEBUG) << "This op does not have output map";
      return FAILED;
    }
  }
  return SUCCESS;
}

size_t OpAdapterImpl::GetCustomOpOutputSize(const CusOperatorPtr &cus_op) const {
  MS_EXCEPTION_IF_NULL(cus_op);
  if (cus_output_map_->find(cus_op->GetOpType()) == cus_output_map_->end()) {
    MS_LOG(ERROR) << "This op does not create custom output map";
    return 0;
  }
  size_t output_size = (*cus_output_map_)[cus_op->GetOpType()].size();
  return output_size;
}

std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type,
                                                              const std::string &format) const {
  if (type == nullptr) {
    MS_LOG(ERROR) << "Type ptr is nullptr";
    return nullptr;
  }

  TypeId me_type = type->type_id();
  if (kObjectTypeTensorType == me_type) {
    me_type = dyn_cast<TensorType>(type)->element()->type_id();
  }

  return TransformUtil::GetGeTensorDesc((shape_ptr == nullptr) ? ShapeVector{} : shape_ptr->shape(), me_type, format);
}

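// Update the tensor descriptors of an op with a tuple output: walk the TupleShape element by
// element, skip None elements, and write each descriptor to the matching static, dynamic or
// custom output. The adapter may declare more outputs than the tuple has elements (reserved
// outputs), but never fewer.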
Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp,
                                            const TypePtr &type, const std::string &format) {
  auto tuple_shp = dyn_cast<abstract::TupleShape>(shp);
  MS_EXCEPTION_IF_NULL(tuple_shp);

  size_t output_size = 0;
  bool is_custom_op = IsCustomOp(op);
  if (is_custom_op) {
    output_size = GetCustomOpOutputSize(std::dynamic_pointer_cast<CustomOperator>(op));
  } else {
    output_size = output_map_.empty()
                    ? static_cast<size_t>(op->GetDynamicOutputNum(dyn_output_map_.begin()->second.name))
                    : output_map_.size();
  }

  if (output_size == 0) {
    MS_LOG(DEBUG) << "This op does not have output map";
    return FAILED;
  }

  // There are scenarios where output_size is greater than the tuple_shape size.
  // Reserved outputs exist in output_map, taking BatchNormGrad as an example.
  if (output_size < tuple_shp->shape().size()) {
    MS_LOG(INFO) << "output_map is smaller than tuple_shape size, node: " << op->GetName();
    return FAILED;
  }

  std::vector<std::vector<int64_t>> out_shapes;
  std::vector<int32_t> out_formats;
  std::vector<int32_t> out_types;
  for (size_t i = 0; i < tuple_shp->shape().size(); ++i) {
    auto tuple_type = dyn_cast<Tuple>(type);
    MS_EXCEPTION_IF_NULL(tuple_type);
    TypePtr type_elem = tuple_type->elements()[i];
    if (type_elem == nullptr) {
      MS_LOG(ERROR) << "Type ptr is nullptr";
      return FAILED;
    }
    TypeId me_type = type_elem->type_id();
    if (kObjectTypeTensorType == me_type) {
      me_type = dyn_cast<TensorType>(type_elem)->element()->type_id();
    }
    if (me_type == kMetaTypeNone) {
      continue;
    }

    auto desc = CreateOutputDesc(dyn_cast<abstract::Shape>(tuple_shp->shape()[i]), type_elem, format);
    if (desc == nullptr) {
      MS_LOG(WARNING) << "Create op: " << op->GetName() << " output descriptor failed!";
      return FAILED;
    }

    if (is_custom_op) {
      (void)std::dynamic_pointer_cast<CustomOperator>(op)->UpdateOutputDesc((*cus_output_map_)[op->GetOpType()][i],
                                                                            *desc);
      out_shapes.push_back(desc->GetShape().GetDims());
      out_formats.push_back(static_cast<int32_t>(desc->GetFormat()));
      out_types.push_back(static_cast<int32_t>(desc->GetDataType()));
    } else {
      auto it = output_map_.find(i);
      if (it != output_map_.end()) {
        it->second.update_out_desc(op, *desc);
      } else if (!dyn_output_map_.empty()) {
        dyn_output_map_.begin()->second.update_dyn_output_desc(op, static_cast<unsigned int>(i), *desc);
      }
    }
  }
  if (is_custom_op) {
    (void)op->SetAttr("output_shapes", out_shapes);
    (void)op->SetAttr("output_formats", out_formats);
    (void)op->SetAttr("output_types", out_types);
  }
  return SUCCESS;
}

std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &node, const std::string &format) const {
  MS_EXCEPTION_IF_NULL(node);
  TypeId me_type = node->Type()->type_id();
  if (kObjectTypeTensorType == me_type) {
    me_type = dyn_cast<TensorType>(node->Type())->element()->type_id();
  }
  if (me_type <= kNumberTypeBegin || me_type >= kNumberTypeEnd) {
    return nullptr;
  }

  std::vector<int64_t> shape;
  auto shape_ptr = dyn_cast<abstract::Shape>(node->Shape());
  if (shape_ptr != nullptr) {
    shape = shape_ptr->shape();
  }

  auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, format);
  if (desc == nullptr) {
    MS_LOG(ERROR) << "Update output descriptor failed!";
    return nullptr;
  }
  return desc;
}

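// Refresh the input descriptors of a normal op from the CNode's real inputs. When the node has
// kAttrDynInputSizes, only the non-folded inputs (marked with -1) are remapped to their adapter
// input index; folded dynamic inputs are skipped.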
void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node, const std::string format) {
  if (op == nullptr) {
    MS_LOG(ERROR) << "op is nullptr";
    return;
  }
  MS_EXCEPTION_IF_NULL(node);
  std::map<size_t, size_t> real_input_map;
  if (!dyn_input_map_.empty() && common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, node->cast<CNodePtr>())) {
    std::vector<int64_t> dyn_input_sizes = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrDynInputSizes);
    if (!dyn_input_sizes.empty()) {
      size_t input_index = kIndex1;
      for (size_t i = 0; i < dyn_input_sizes.size(); ++i) {
        int64_t dyn_input_size = dyn_input_sizes[i];
        if (dyn_input_size < 0) {
          real_input_map[input_index] = i + kIndex1;
          input_index += 1;
        } else {
          input_index += dyn_input_size;
        }
      }
    }
  }

  auto inputs = node->cast<CNodePtr>()->inputs();
  for (size_t i = 1; i < inputs.size(); ++i) {
    size_t real_input_index = i;
    if (!real_input_map.empty()) {
      auto iter = real_input_map.find(i);
      if (iter != real_input_map.end()) {
        real_input_index = iter->second;
      } else {
        continue;
      }
    }
    auto it = input_map_.find(real_input_index);
    if (it != input_map_.end()) {
      auto desc = CreateNodeDesc(inputs[i], format);
      if (desc == nullptr) {
        continue;
      }

      it->second.update_input_desc(op, *desc);
    }
  }
}

void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node,
                                            const std::string format) const {
  if (op == nullptr) {
    MS_LOG(ERROR) << "op is nullptr";
    return;
  }
  MS_EXCEPTION_IF_NULL(node);

  if (cus_input_map_->find(op->GetOpType()) == cus_input_map_->end() || ((*cus_input_map_)[op->GetOpType()].empty())) {
    MS_LOG(ERROR) << "This op does not create custom input map";
    return;
  }

  mindspore::HashMap<int, std::string> &input_map = (*cus_input_map_)[op->GetOpType()];
  auto inputs = node->cast<CNodePtr>()->inputs();
  for (size_t i = 1; i < inputs.size(); ++i) {
    if (input_map.find(i) != input_map.end()) {
      auto desc = CreateNodeDesc(inputs[i], format);
      if (desc == nullptr) {
        continue;
      }
      (void)op->UpdateInputDesc(input_map[i], *desc);
    }
  }
}

void OpAdapterImpl::updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(op);
  MS_EXCEPTION_IF_NULL(node);
  std::string format = GetOpIOFormat(node);
  if (IsCustomOp(op)) {
    auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
    UpdateCustomOpInputDesc(cus_op, node, format);
  } else {
    UpdateNormalOpInputDesc(op, node, format);
  }
}

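// Entry point for descriptor updates: dispatch to the single-output or multi-output path based
// on the node's shape kind (Shape/NoShape vs TupleShape), then refresh the input descriptors of
// the same node so that both sides of the GE operator stay consistent.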
void OpAdapterImpl::updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
                                     const AnfNodePtr &node) {
  if (op == nullptr) {
    MS_LOG(ERROR) << "op is nullptr";
    return;
  }
  MS_EXCEPTION_IF_NULL(node);
  MS_LOG(DEBUG) << "Op name is " << op->GetName() << " anf is " << node->DebugString() << ", shape: " << shp->ToString()
                << ", type: " << type;
  if (AnfUtils::GetOutputTensorNum(node) == 0) {
    return;
  }

  auto normal_shape_ptr = dyn_cast<abstract::Shape>(shp);
  auto no_shape_ptr = dyn_cast<abstract::NoShape>(shp);
  std::string format = GetOpIOFormat(node);

  if ((normal_shape_ptr != nullptr) || (no_shape_ptr != nullptr)) {
    if (UpdateSingleOutputDesc(op, shp, type, format) != SUCCESS) {
      return;
    }
  } else if (dyn_cast<abstract::TupleShape>(shp) != nullptr) {
    if (UpdateMultiOutputDesc(op, shp, type, format) != SUCCESS) {
      return;
    }
  } else {
    MS_LOG(WARNING) << "Update output desc failed, unknown output shape type";
    return;
  }
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return;
  }

  // Need to update input_desc while the output_desc is updated
  updateInputDesc(op, node);
}

int OpAdapterImpl::setAttr(const OperatorPtr &op, const std::string &attr_key, const ValuePtr &attr_value) {
  auto it = attr_map_.find(attr_key);
  if (it != attr_map_.end()) {
    // switch case for each available attribute type
    MS_LOG(DEBUG) << "Op: " << op->GetName() << ", set attr: " << attr_key << "(" << it->second.name
                  << "), value: " << attr_value->ToString();
    adpt_->AddAttrToDrawGraph(attr_key + std::string("=") + attr_value->ToString());
    it->second.set_attr(op, attr_value);
    return 0;
  }
  return static_cast<int>(NOT_FOUND);
}

int OpAdapterImpl::setAttr(const OperatorPtr &op, const uint32_t &input_idx, const ValuePtr &attr_value) {
  auto it = input_attr_map_.find(input_idx);
  if (it != input_attr_map_.end()) {
    it->second.set_attr(op, attr_value);
    return static_cast<int>(SUCCESS);
  }
  return static_cast<int>(NOT_FOUND);
}

int OpAdapterImpl::getAttr(const OperatorPtr &op, const std::string &attr_key, ValuePtr *attr_value) {
  MS_EXCEPTION_IF_NULL(attr_value);
  auto it = attr_map_.find(attr_key);
  if (it != attr_map_.end()) {
    it->second.get_attr(op, attr_value);
    return static_cast<int>(SUCCESS);
  }
  return static_cast<int>(NOT_FOUND);
}

int OpAdapterImpl::getAttr(const OperatorPtr &op, uint32_t input_idx, ValuePtr *attr_value) {
  MS_EXCEPTION_IF_NULL(attr_value);
  auto it = input_attr_map_.find(input_idx);
  if (it != input_attr_map_.end()) {
    it->second.get_attr(op, attr_value);
    return static_cast<int>(SUCCESS);
  }
  return static_cast<int>(NOT_FOUND);
}

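// Copy the primitive's attributes onto a custom GE operator, mapping MindSpore value types
// (Int32/Int64/String/Bool/FP32 and sequences of those) to the corresponding SetAttr overloads.
// An unsupported scalar type aborts with NOT_FOUND, while an unsupported element type inside a
// sequence raises an exception.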
int OpAdapterImpl::SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim) const {
  enum ValueType {
    SINGLE_VALUE = 0,
    SEQUEUE_VALUE,
    UNKNOWN_VALUE,
  };

  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(op);

  ValueType value_type = SINGLE_VALUE;
  static std::unordered_set<std::string> excluded_attr{"IsFeatureMapInputList", "IsFeatureMapOutput"};
  for (auto item : prim->attrs()) {
    if (excluded_attr.find(item.first) != excluded_attr.end()) {
      continue;
    }
    if (item.second->isa<Int32Imm>()) {
      (void)op->SetAttr(item.first, GetValue<int32_t>(item.second));
    } else if (item.second->isa<Int64Imm>()) {
      (void)op->SetAttr(item.first, GetValue<int64_t>(item.second));
    } else if (item.second->isa<StringImm>()) {
      (void)op->SetAttr(item.first, GetValue<std::string>(item.second));
    } else if (item.second->isa<BoolImm>()) {
      (void)op->SetAttr(item.first, GetValue<bool>(item.second));
    } else if (item.second->isa<FP32Imm>()) {
      (void)op->SetAttr(item.first, GetValue<float>(item.second));
    } else if (item.second->isa<ValueSequence>()) {
      value_type = SEQUEUE_VALUE;
      auto val_seq = item.second->cast<ValueSequencePtr>();
      if (val_seq->size() == 0) {
        std::vector<int64_t> value;
        (void)op->SetAttr(item.first, value);
        continue;
      }
      if ((*val_seq)[0]->isa<StringImm>()) {
        (void)op->SetAttr(item.first, GetValue<const std::vector<std::string>>(item.second));
      } else if ((*val_seq)[0]->isa<FP32Imm>()) {
        (void)op->SetAttr(item.first, GetValue<const std::vector<float>>(item.second));
      } else if ((*val_seq)[0]->isa<Int32Imm>()) {
        (void)op->SetAttr(item.first, GetValue<const std::vector<int32_t>>(item.second));
      } else if ((*val_seq)[0]->isa<Int64Imm>()) {
        (void)op->SetAttr(item.first, GetValue<const std::vector<int64_t>>(item.second));
      } else if ((*val_seq)[0]->isa<BoolImm>()) {
        (void)op->SetAttr(item.first, GetValue<const std::vector<bool>>(item.second));
      } else {
        MS_LOG(EXCEPTION) << "Unsupported custom attribute type in adaptor, prim name: " << prim->name()
                          << ", attr name: " << item.first << ", value: " << item.second->ToString();
      }
    } else {
      MS_LOG(WARNING) << "Unsupported custom attribute type in adaptor, prim name: " << prim->name()
                      << ", attr name: " << item.first << ", value: " << item.second->ToString();
      return static_cast<int>(NOT_FOUND);
    }

    if (value_type == SINGLE_VALUE) {
      adpt_->AddAttrToDrawGraph(item.first + std::string("=") + item.second->ToString());
    } else if (value_type == SEQUEUE_VALUE) {
      adpt_->AddAttrToDrawGraph(item.first + std::string("=") + "[...]");
    }
  }
  return 0;
}

std::map<std::string, ValuePtr> OpAdapterImpl::GetOpAttrList(const OperatorPtr &op) const {
  std::map<std::string, ValuePtr> attr_list;
  for (auto &it : attr_map_) {
    ValuePtr value = nullptr;
    it.second.get_attr(op, &value);
    (void)attr_list.emplace(it.second.name, value);
  }
  for (auto &it : input_attr_map_) {
    ValuePtr value = nullptr;
    it.second.get_attr(op, &value);
    (void)attr_list.emplace(it.second.name, value);
  }
  return attr_list;
}

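// Gather the attribute values of a normal op as seen from the anf node: values recorded in
// extra_attr_ win over the primitive's own attrs, const-input attrs are read from the node's
// value-node inputs (honoring dynamic-input remapping), and attrs that are converted to inputs
// are appended from the primitive.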
std::map<std::string, ValuePtr> OpAdapterImpl::GetNormalOpAttrList(const OperatorPtr &op,
                                                                   const AnfNodePtr &node) const {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>() || node->cast<CNodePtr>() == nullptr) {
    return {};
  }
  auto cnode = node->cast<CNodePtr>();
  auto &inputs = cnode->inputs();
  if (inputs.empty() || !IsValueNode<Primitive>(inputs[0])) {
    return {};
  }

  auto prim = GetValueNode<PrimitivePtr>(inputs[0]);
  std::map<std::string, ValuePtr> attr_list;
  for (auto &it : attr_map_) {
    // set attr from extra_attr
    auto it_extra = extra_attr_->find(it.first);
    if (it_extra != extra_attr_->end()) {
      auto value = it_extra->second;
      (void)attr_list.emplace(it.second.name, value);
    } else {
      auto value = prim->GetAttr(it.first);
      it.second.get_attr(op, &value);
      (void)attr_list.emplace(it.second.name, value);
    }
  }

  auto real_input_indices = GetRealInputIndices(cnode);
  // set attr from const input
  for (auto &it : input_attr_map_) {
    size_t cur_idx = GetRealAnfInputIndex(real_input_indices, it.first);
    if (inputs.size() <= cur_idx || !inputs[cur_idx]->isa<ValueNode>()) {
      continue;
    }
    auto const_value = GetValueNode(inputs[cur_idx]);
    MS_LOG(DEBUG) << "Get input attr: input_" << cur_idx << "(" << it.second.name
                  << "), value: " << const_value->ToString();
    if (const_value->isa<None>()) {
      continue;
    }
    (void)attr_list.emplace(it.second.name, const_value);
  }

  // Get attrs that need to be converted to inputs
  for (auto &it : attr_input_map_) {
    auto value = prim->GetAttr(it.first);
    if (value == nullptr) {
      continue;
    }
    (void)attr_list.emplace(it.first, value);
  }
  return attr_list;
}

int OpAdapterImpl::SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(op);
  for (auto &it : attr_map_) {
    if (attr_input_map_.count(it.first) != 0) {
      MS_LOG(WARNING) << "Attr: " << it.first << " will convert to input, please del it from ATTR_MAP.";
      continue;
    }
    auto value = prim->GetAttr(it.first);
    if (value != nullptr) {
      // convert parts of attrs to string, e.g. data_format, or change an IR attr to an op attr, e.g. axis[0]
      (void)CheckAndConvertUtils::ConvertAttrValueToString(prim->name(), it.first, &value);
      (void)CheckAndConvertUtils::CheckIrAttrtoOpAttr(prim->name(), it.first, &value);
      // set attr from primitive
      int ret = setAttr(op, it.first, value);
      if (ret != 0) {
        return ret;
      }
    } else {
      // set attr from extra_attr
      auto it_extra = extra_attr_->find(it.first);
      if (it_extra != extra_attr_->end()) {
        int ret = setAttr(op, it.first, it_extra->second);
        if (ret != 0) {
          return ret;
        }
      }
    }
  }
  return 0;
}

int OpAdapterImpl::SetNoFoldingOpAttr(const OperatorPtr &op, const PrimitivePtr &prim) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(op);
  op->SetAttr("no_need_constant_folding", true);
  return SetNormalOpAttr(op, prim);
}

int OpAdapterImpl::setAttr(const OperatorPtr &op, const PrimitivePtr &prim) {
  int ret = 0;
  if (IsCustomPrim(prim)) {
    auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
    ret = SetCustomOpAttr(cus_op, prim);
  } else if (IsNoNeedConstantFoldCNode(prim)) {
    ret = SetNoFoldingOpAttr(op, prim);
  } else {
    ret = SetNormalOpAttr(op, prim);
  }
  return ret;
}

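// Set attributes from an anf node: derive attr "T" from the node's (or first input's) dtype,
// copy attrs from the primitive and extra_attr_, and finally materialize const-input attrs,
// converting constant tensors to plain values before assigning them to the operator.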
int OpAdapterImpl::setAttr(const OperatorPtr &op, const AnfNodePtr &node) {
  // no attribute for lonely node
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>() || node->cast<CNodePtr>() == nullptr) {
    return 0;
  }

  auto cnode = node->cast<CNodePtr>();
  auto &inputs = cnode->inputs();
  if (inputs.empty()) {
    return 0;
  }

  // get attr "T" from the abstract of the anf node first;
  // if attr "T" also appears in the primitive, the primitive's T will override this one
  if (attr_map_.find("T") != attr_map_.end()) {
    // get dtype from inputs[1]; if the node has no inputs, set the attr T with the output dtype
    TypePtr type;
    if (inputs.size() > 1) {
      type = inputs[1]->Type();
    } else {
      type = node->Type();
    }
    if (type != nullptr) {
      (void)setAttr(op, "T", MakeValue(type));
    }
  }

  // set attr from primitive and ExtraAttr
  if (IsValueNode<Primitive>(inputs[0])) {
    // set attr from primitive
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(inputs[0]);
    int ret = setAttr(op, prim);
    if (ret != 0) {
      return ret;
    }
  }

  auto real_input_indices = GetRealInputIndices(cnode);
  // set attr from const input
  for (auto &it : input_attr_map_) {
    size_t cur_idx = GetRealAnfInputIndex(real_input_indices, it.first);
    if (inputs.size() <= cur_idx || !inputs[cur_idx]->isa<ValueNode>()) {
      continue;
    }

    auto const_value = GetValueNode(inputs[cur_idx]);
    MS_LOG(INFO) << "Set attr: input_" << cur_idx << "(" << it.second.name << "), value: " << const_value->ToString();
    if (const_value->isa<None>()) {
      continue;
    }
    if (const_value->isa<mindspore::tensor::Tensor>()) {
      auto tensorptr = const_value->cast<mindspore::tensor::TensorPtr>();
      const_value = CreateValueFromTensor(tensorptr);
    }

    adpt_->AddAttrToDrawGraph(it.second.name + std::string("=") + const_value->ToString());
    it.second.set_attr(op, const_value);
  }
  return 0;
}
}  // namespace transform
}  // namespace mindspore