1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "transform/graph_ir/op_adapter.h"
17
18 #include <algorithm>
19 #include <map>
20 #include <string>
21 #include <unordered_set>
22 #include "utils/check_convert_utils.h"
23 #include "op_proto/inc/split_combination_ops.h"
24 #include "graph/operator_factory.h"
25 #include "include/common/utils/convert_utils.h"
26 #include "utils/anf_utils.h"
27 #include "include/common/utils/anfalgo.h"
28
29 namespace mindspore {
30 namespace transform {
// Infer function attached to AKG ("GraphKernel") custom ops; declared here, defined outside this chunk.
ge::graphStatus CustomAkgOpInferFunc(ge::Operator &op);

// Infer function shared by TBE and AICPU custom ops; declared here, defined outside this chunk.
ge::graphStatus CustomTbeAicpuOpInferFunc(ge::Operator &op);

// Backend kind of a custom op, as classified from primitive attrs by GetCustomOpTypeDetail().
enum class CustomOpType { kUnKnown, kAkg, kTbe, kAiCpu };
36
GetCustomOpTypeDetail(const PrimitivePtr & prim)37 CustomOpType GetCustomOpTypeDetail(const PrimitivePtr &prim) {
38 if (prim == nullptr) {
39 return CustomOpType::kUnKnown;
40 }
41 auto type = prim->GetAttr("type");
42 if (type != nullptr && GetValue<std::string>(type) == "GraphKernel") {
43 return CustomOpType::kAkg;
44 }
45 auto func_type = prim->GetAttr("func_type");
46 if (func_type != nullptr) {
47 auto func_type_value = GetValue<std::string>(func_type);
48 if (func_type_value == "tbe") {
49 return CustomOpType::kTbe;
50 } else if (func_type_value == "aicpu") {
51 return CustomOpType::kAiCpu;
52 }
53 }
54 return CustomOpType::kUnKnown;
55 }
56
GetCustomOpInputNames(const PrimitivePtr & prim)57 ValuePtr GetCustomOpInputNames(const PrimitivePtr &prim) {
58 MS_EXCEPTION_IF_NULL(prim);
59 // MS Custom op 'input_names' include attr names, which is not expected
60 auto value = prim->GetAttr("pure_input_names");
61 if (value == nullptr) {
62 value = prim->GetAttr("input_names");
63 }
64 return value;
65 }
66
GetCustomOpKernelAttrs(const PrimitivePtr & prim)67 std::vector<std::string> GetCustomOpKernelAttrs(const PrimitivePtr &prim) {
68 MS_EXCEPTION_IF_NULL(prim);
69 std::vector<std::string> res;
70 auto op_type = GetCustomOpTypeDetail(prim);
71 auto attr_names = prim->GetAttr("attr_names");
72 if (attr_names != nullptr) {
73 auto names = GetValue<std::vector<std::string>>(attr_names);
74 std::vector<std::string> optional_attrs;
75 auto attr_ptr = prim->GetAttr("missing_optional_attrs");
76 if (attr_ptr != nullptr) {
77 optional_attrs = GetValue<std::vector<std::string>>(attr_ptr);
78 }
79 for (const auto &name : names) {
80 if (!prim->HasAttr(name)) {
81 // optional attr can have no value, but required attr must have value
82 if (std::find(optional_attrs.begin(), optional_attrs.end(), name) == std::end(optional_attrs)) {
83 MS_LOG(ERROR) << "Custom op attr '" << name << "' value not set";
84 }
85 } else {
86 if (op_type == CustomOpType::kAiCpu && name == "cust_aicpu") {
87 continue;
88 }
89 res.push_back(name);
90 }
91 }
92 }
93 return res;
94 }
95
RegisterCustomOp(const PrimitivePtr & prim,const std::string & op_type,const std::vector<std::string> & attr_names,bool is_akg)96 void RegisterCustomOp(const PrimitivePtr &prim, const std::string &op_type, const std::vector<std::string> &attr_names,
97 bool is_akg) {
98 if (ge::OperatorFactory::IsExistOp(op_type)) {
99 return;
100 }
101 MS_EXCEPTION_IF_NULL(prim);
102 auto input_names_v = GetCustomOpInputNames(prim);
103 MS_EXCEPTION_IF_NULL(input_names_v);
104 auto input_names = GetValue<std::vector<std::string>>(input_names_v);
105 auto output_names_v = prim->GetAttr("output_names");
106 MS_EXCEPTION_IF_NULL(output_names_v);
107 auto output_names = GetValue<std::vector<std::string>>(output_names_v);
108 // Register op create function, which describes how to create a custom op
109 ::ge::OperatorCreatorRegister op_create_reg(
110 op_type, [op_type, input_names, output_names, attr_names, is_akg](const std::string &name) {
111 auto op = ge::CustomOperator(name, op_type);
112 for (const auto &in_name : input_names) {
113 op.CustomInputRegister(in_name);
114 }
115 for (const auto &out_name : output_names) {
116 op.CustomOutputRegister(out_name);
117 }
118 for (const auto &attr_name : attr_names) {
119 op.CustomRequiredAttrRegister(attr_name);
120 }
121 if (is_akg) {
122 op.CustomInferFuncRegister(CustomAkgOpInferFunc);
123 } else {
124 op.CustomInferFuncRegister(CustomTbeAicpuOpInferFunc);
125 }
126 return op;
127 });
128 // Register op infer shape function
129 if (is_akg) {
130 ::ge::InferShapeFuncRegister infer(op_type, CustomAkgOpInferFunc);
131 } else {
132 ::ge::InferShapeFuncRegister infer(op_type, CustomTbeAicpuOpInferFunc);
133 }
134 }
135
GetRealInputIndices(const CNodePtr & cnode)136 static std::vector<int64_t> GetRealInputIndices(const CNodePtr &cnode) {
137 if (!common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, cnode)) {
138 return std::vector<int64_t>{};
139 }
140 std::vector<int64_t> real_input_indices =
141 common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrDynInputSizes);
142 int64_t count = 0;
143 // construct real input indices based on attribute kAttrDynInputSizes
144 for (size_t i = 0; i < real_input_indices.size(); ++i) {
145 int64_t num_folded_inputs = (real_input_indices[i] < 0 ? 1 : real_input_indices[i]);
146 real_input_indices[i] = count;
147 count += num_folded_inputs;
148 }
149 return real_input_indices;
150 }
151
// Map a 1-based anf input index (index 0 is the primitive value node) through `real_input_indices`
// (0-based offsets, as produced by GetRealInputIndices) and return the resulting 1-based index.
// Indices beyond the end of `real_input_indices` are returned unchanged.
static inline size_t GetRealAnfInputIndex(const std::vector<int64_t> &real_input_indices, size_t anf_input_index) {
  const size_t zero_based = anf_input_index - 1;
  if (zero_based >= real_input_indices.size()) {
    return anf_input_index;
  }
  return static_cast<size_t>(real_input_indices[zero_based]) + 1;
}
160
IsCustomOp(const OperatorPtr & op) const161 bool OpAdapterImpl::IsCustomOp(const OperatorPtr &op) const {
162 MS_EXCEPTION_IF_NULL(op);
163 auto it = cus_input_map_->find(op->GetOpType());
164 if (it == cus_input_map_->end()) {
165 return false;
166 }
167 return true;
168 }
169
GenerateCustomOpInputMap(const CusOperatorPtr & op,const PrimitivePtr & prim)170 Status OpAdapterImpl::GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) {
171 MS_EXCEPTION_IF_NULL(op);
172 MS_EXCEPTION_IF_NULL(prim);
173 // Create the map of custom op from input index to input name.
174 mindspore::HashMap<int, std::string> input_map;
175 auto op_type = GetCustomOpType(prim);
176 auto value = GetCustomOpInputNames(prim);
177 if (value == nullptr) {
178 (*cus_output_map_)[op_type] = std::map<int, std::string>{};
179 return NOT_FOUND;
180 }
181
182 auto input_names = GetValue<const std::vector<std::string>>(value);
183 for (size_t i = 0; i < input_names.size(); ++i) {
184 // input_map begin form 1
185 input_map[i + 1] = input_names[i];
186 op->CustomInputRegister(input_names[i]);
187 }
188
189 if (cus_input_map_->find(op_type) == cus_input_map_->end()) {
190 (*cus_input_map_)[op_type] = input_map;
191 }
192 return SUCCESS;
193 }
194
GenerateCustomOpOutputMap(const CusOperatorPtr & op,const PrimitivePtr & prim)195 Status OpAdapterImpl::GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) {
196 MS_EXCEPTION_IF_NULL(op);
197 MS_EXCEPTION_IF_NULL(prim);
198 // Create the map of custom op from output index to output name.
199 std::map<int, std::string> output_map;
200 auto op_type = GetCustomOpType(prim);
201 auto value = prim->GetAttr("output_names");
202 if (value == nullptr) {
203 // generate a empty output_map for it
204 (*cus_output_map_)[op_type] = output_map;
205 return NOT_FOUND;
206 }
207
208 auto output_names = GetValue<const std::vector<std::string>>(value);
209 for (size_t i = 0; i < output_names.size(); ++i) {
210 // output_map begin form 0
211 output_map[i] = output_names[i];
212 op->CustomOutputRegister(output_names[i]);
213 }
214
215 if (cus_output_map_->find(op_type) == cus_output_map_->end()) {
216 (*cus_output_map_)[op_type] = output_map;
217 }
218 return SUCCESS;
219 }
220
GetCustomOpType(const PrimitivePtr & prim) const221 std::string OpAdapterImpl::GetCustomOpType(const PrimitivePtr &prim) const {
222 MS_EXCEPTION_IF_NULL(prim);
223 auto detail_type = GetCustomOpTypeDetail(prim);
224 if (detail_type == CustomOpType::kTbe) {
225 auto func_name = prim->GetAttr("func_name");
226 if (func_name == nullptr) {
227 MS_LOG(ERROR) << "Custom tbe op has no 'func_name' attr.";
228 return "";
229 }
230 return GetValue<std::string>(func_name);
231 }
232 auto value = prim->GetAttr("reg_op_name");
233 if (value == nullptr) {
234 MS_LOG(ERROR) << "Custom op has no reg_op_name attr.";
235 return "";
236 }
237 auto op_type = GetValue<std::string>(value);
238 return op_type;
239 }
240
GenerateCustomOp(const AnfNodePtr anf)241 OperatorPtr OpAdapterImpl::GenerateCustomOp(const AnfNodePtr anf) {
242 MS_EXCEPTION_IF_NULL(anf);
243 auto node = anf->cast<CNodePtr>();
244 if (node == nullptr) {
245 return nullptr;
246 }
247
248 if (node->inputs().empty()) {
249 MS_LOG(EXCEPTION) << "length of node inputs is empty";
250 }
251
252 auto prim = GetValueNode<PrimitivePtr>(node->inputs()[0]);
253 MS_EXCEPTION_IF_NULL(prim);
254 auto op_type = GetCustomOpType(prim);
255 auto op = std::make_shared<::ge::CustomOperator>(node->fullname_with_scope() + op_type, op_type);
256 MS_EXCEPTION_IF_NULL(op);
257 if (GenerateCustomOpInputMap(op, prim) != SUCCESS) {
258 MS_LOG(WARNING) << "Custom op node has no input_names, op[" << prim->name() << "].";
259 }
260
261 if (GenerateCustomOpOutputMap(op, prim) != SUCCESS) {
262 MS_LOG(WARNING) << "Custom op node has no output_names, op[" << prim->name() << "].";
263 }
264
265 auto detail_type = GetCustomOpTypeDetail(prim);
266 if (detail_type == CustomOpType::kAkg) {
267 std::vector<std::string> attr_names{"info_path"};
268 op->CustomRequiredAttrRegister(attr_names[0]);
269 op->CustomInferFuncRegister(CustomAkgOpInferFunc);
270 RegisterCustomOp(prim, op_type, attr_names, true);
271 } else if (detail_type == CustomOpType::kTbe || detail_type == CustomOpType::kAiCpu) {
272 auto attr_names = GetCustomOpKernelAttrs(prim);
273 for (const auto &attr_name : attr_names) {
274 op->CustomRequiredAttrRegister(attr_name);
275 }
276 op->CustomInferFuncRegister(CustomTbeAicpuOpInferFunc);
277 RegisterCustomOp(prim, op_type, attr_names, false);
278 } else {
279 MS_LOG(INFO) << "For custom operators, users need to define and implement the Infershape function by themselves.";
280 }
281
282 return op;
283 }
284
SetOpSubgraphFunc(const OperatorPtr & op,int index,const std::shared_ptr<std::vector<DfGraph>> & branches)285 Status OpAdapterImpl::SetOpSubgraphFunc(const OperatorPtr &op, int index,
286 const std::shared_ptr<std::vector<DfGraph>> &branches) {
287 MS_EXCEPTION_IF_NULL(op);
288 auto it = dyn_subgraph_map_.find(index);
289 if (it != dyn_subgraph_map_.end()) {
290 auto size = branches->size();
291 it->second.create_dyn_subgraph(op, static_cast<unsigned int>(size));
292 for (size_t i = 0; i < size; i++) {
293 it->second.set_subgraph(op, static_cast<unsigned int>(i), std::make_shared<DfGraph>((*branches)[i]));
294 }
295 return SUCCESS;
296 }
297 return NOT_FOUND;
298 }
299
SetOpSubgraphFunc(const OperatorPtr & op,const std::shared_ptr<std::vector<DfGraph>> & subgraphs)300 Status OpAdapterImpl::SetOpSubgraphFunc(const OperatorPtr &op, const std::shared_ptr<std::vector<DfGraph>> &subgraphs) {
301 MS_EXCEPTION_IF_NULL(op);
302 if (subgraph_map_.size() != subgraphs->size()) {
303 return INVALID_ARGUMENT;
304 }
305 for (size_t i = 0; i < subgraphs->size(); i++) {
306 subgraph_map_.at(i).set_subgraph(op, std::make_shared<DfGraph>((*subgraphs)[i]));
307 }
308 return SUCCESS;
309 }
310
SetCustomOpInput(const CusOperatorPtr & op,int index,const OperatorPtr & input) const311 Status OpAdapterImpl::SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) const {
312 MS_EXCEPTION_IF_NULL(op);
313 MS_EXCEPTION_IF_NULL(input);
314 auto it = cus_input_map_->find(op->GetOpType());
315 if (it == cus_input_map_->end()) {
316 return NOT_FOUND;
317 }
318 mindspore::HashMap<int, std::string> &input_map = it->second;
319
320 if ((input_map.find(index) != input_map.end())) {
321 MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << input_map[index];
322 (void)op->SetInput(input_map[index], *input);
323 return SUCCESS;
324 }
325 return NOT_FOUND;
326 }
327
SetNormalOpInput(const OperatorPtr & op,int index,const OperatorPtr & input)328 Status OpAdapterImpl::SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input) {
329 MS_EXCEPTION_IF_NULL(op);
330 auto it = input_map_.find(index);
331 if (input != nullptr && it != input_map_.end()) {
332 MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << it->second.name;
333 it->second.set_op(op, input);
334 return SUCCESS;
335 }
336 return NOT_FOUND;
337 }
338
setInput(const OperatorPtr & op,int index,const OperatorPtr & input)339 int OpAdapterImpl::setInput(const OperatorPtr &op, int index, const OperatorPtr &input) {
340 if (IsCustomOp(op)) {
341 auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
342 return static_cast<int>(SetCustomOpInput(cus_op, index, input));
343 } else {
344 return static_cast<int>(SetNormalOpInput(op, index, input));
345 }
346 }
347
SetCustomOpInput(const CusOperatorPtr & op,int index,const OutHandler & handle) const348 Status OpAdapterImpl::SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle) const {
349 MS_EXCEPTION_IF_NULL(op);
350 auto it = cus_input_map_->find(op->GetOpType());
351 if (it == cus_input_map_->end()) {
352 return NOT_FOUND;
353 }
354
355 mindspore::HashMap<int, std::string> &input_map = it->second;
356 if ((handle.op != nullptr) && (input_map.find(index) != input_map.end())) {
357 if (handle.out.empty()) {
358 MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << input_map[index];
359 (void)op->SetInput(input_map[index], *(handle.op));
360 } else {
361 MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << ":" << handle.out << " to " << op->GetName() << ":"
362 << input_map[index];
363 (void)op->SetInput(input_map[index], *(handle.op), handle.out);
364 }
365 return SUCCESS;
366 }
367 return NOT_FOUND;
368 }
369
SetNormalOpInput(const OperatorPtr & op,int index,const OutHandler & handle)370 Status OpAdapterImpl::SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle) {
371 MS_EXCEPTION_IF_NULL(op);
372 auto it = input_map_.find(index);
373 if ((handle.op != nullptr) && (it != input_map_.end())) {
374 if (handle.out.empty()) {
375 MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << it->second.name;
376 it->second.set_op(op, handle.op);
377 } else {
378 MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << ":" << handle.out << " to " << op->GetName() << ":"
379 << it->second.name;
380 it->second.set_handle(op, handle);
381 }
382 return SUCCESS;
383 }
384 return NOT_FOUND;
385 }
386
setInput(const OperatorPtr & op,int index,const OutHandler & handle)387 int OpAdapterImpl::setInput(const OperatorPtr &op, int index, const OutHandler &handle) {
388 if (IsCustomOp(op)) {
389 auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
390 return static_cast<int>(SetCustomOpInput(cus_op, index, handle));
391 } else {
392 return static_cast<int>(SetNormalOpInput(op, index, handle));
393 }
394 }
395
setInput(const OperatorPtr & op,int index,const std::shared_ptr<std::vector<OutHandler>> & handler_vec,bool use_create_byindex_func,size_t dyn_index)396 int OpAdapterImpl::setInput(const OperatorPtr &op, int index,
397 const std::shared_ptr<std::vector<OutHandler>> &handler_vec, bool use_create_byindex_func,
398 size_t dyn_index) {
399 MS_EXCEPTION_IF_NULL(handler_vec);
400 if (IsCustomOp(op)) {
401 MS_LOG(ERROR) << "Custom Op do not support dynamic input";
402 return static_cast<int>(FAILED);
403 }
404 MS_EXCEPTION_IF_NULL(op);
405 auto it = dyn_input_map_.find(index);
406 if (it != dyn_input_map_.end()) {
407 if (use_create_byindex_func) {
408 it->second.create_dyn_input_by_index(op, static_cast<unsigned int>(handler_vec->size()), dyn_index);
409 } else {
410 it->second.create_dyn_input(op, static_cast<unsigned int>(handler_vec->size()));
411 }
412 for (unsigned int i = 0; i < handler_vec->size(); ++i) {
413 OutHandler h = (*handler_vec)[i];
414 MS_EXCEPTION_IF_NULL(h.op);
415 if (h.out.empty()) {
416 MS_LOG(DEBUG) << "Link op " << h.op->GetName() << " to " << op->GetName() << ":" << it->second.name;
417 it->second.set_op(op, (i), h.op);
418 } else {
419 MS_LOG(DEBUG) << "Link op " << h.op->GetName() << ":" << h.out << " to " << op->GetName() << ":"
420 << it->second.name;
421 it->second.set_handle(op, i, h);
422 }
423 }
424 return 0;
425 }
426 return static_cast<int>(NOT_FOUND);
427 }
428
getOutput(const OperatorPtr & op,int index)429 OutHandler OpAdapterImpl::getOutput(const OperatorPtr &op, int index) {
430 MS_EXCEPTION_IF_NULL(op);
431 if (IsCustomOp(op)) {
432 return getCustomOutput(op, index);
433 }
434 return getNormalOutput(op, index);
435 }
436
getOutputs(const OperatorPtr & op) const437 std::vector<OutHandler> OpAdapterImpl::getOutputs(const OperatorPtr &op) const {
438 if (IsCustomOp(op)) {
439 return getCustomOutputs(op);
440 }
441 return getNormalOutputs(op);
442 }
443
getCustomOutput(const OperatorPtr & op,int index) const444 OutHandler OpAdapterImpl::getCustomOutput(const OperatorPtr &op, int index) const {
445 MS_EXCEPTION_IF_NULL(op);
446 auto it = cus_output_map_->find(op->GetOpType());
447 if (it == cus_output_map_->end()) {
448 MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT is not supported!";
449 return OutHandler();
450 }
451
452 std::map<int, std::string> &output_map = it->second;
453
454 if ((output_map.find(index) != output_map.end())) {
455 return OutHandler(op, output_map[index]);
456 }
457 MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has no OUTPUT index(" << index << ")!";
458 return OutHandler();
459 }
460
getNormalOutput(const OperatorPtr & op,int index)461 OutHandler OpAdapterImpl::getNormalOutput(const OperatorPtr &op, int index) {
462 MS_EXCEPTION_IF_NULL(op);
463 if (!dyn_output_map_.empty() && !output_map_.empty()) {
464 MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT and DYN_OUTPUT is not supported!";
465 return OutHandler();
466 }
467 auto it = output_map_.find(index);
468 if (it != output_map_.end()) {
469 return OutHandler(op, it->second.name);
470 } else if (!dyn_output_map_.empty()) {
471 return OutHandler(op, dyn_output_map_.begin()->second.name + std::to_string(index));
472 } else {
473 MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has no OUTPUT and DYN_OUTPUT index(" << index << ")!";
474 return OutHandler();
475 }
476 }
477
getNormalOutputs(const OperatorPtr & op) const478 std::vector<OutHandler> OpAdapterImpl::getNormalOutputs(const OperatorPtr &op) const {
479 MS_EXCEPTION_IF_NULL(op);
480 if (!dyn_output_map_.empty() && !output_map_.empty()) {
481 MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT and DYN_OUTPUT is not supported!";
482 return std::vector<OutHandler>{};
483 }
484 std::vector<OutHandler> handles;
485 std::transform(output_map_.begin(), output_map_.end(), std::back_inserter(handles),
486 [&op](const auto &item) { return OutHandler(op, item.second.name); });
487 if (!dyn_output_map_.empty()) {
488 auto dyn_output_name = dyn_output_map_.begin()->second.name;
489 auto dyn_output_size = op->GetDynamicOutputNum(dyn_output_name);
490 for (int i = 0; i < dyn_output_size; i++) {
491 handles.emplace_back(OutHandler(op, dyn_output_name + std::to_string(i)));
492 }
493 }
494 return handles;
495 }
496
getCustomOutputs(const OperatorPtr & op) const497 std::vector<OutHandler> OpAdapterImpl::getCustomOutputs(const OperatorPtr &op) const {
498 MS_EXCEPTION_IF_NULL(op);
499 std::vector<OutHandler> handles;
500 auto it = cus_output_map_->find(op->GetOpType());
501 if (it == cus_output_map_->end()) {
502 MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ")'s OUTPUT is not supported!";
503 return handles;
504 }
505 std::transform(it->second.begin(), it->second.end(), std::back_inserter(handles),
506 [&op](const auto &item) { return OutHandler(op, item.second); });
507 return handles;
508 }
509
UpdateSingleOutputDesc(const OperatorPtr & op,const abstract::BaseShapePtr & shp,const TypePtr & type,const std::string & format)510 Status OpAdapterImpl::UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp,
511 const TypePtr &type, const std::string &format) {
512 MS_EXCEPTION_IF_NULL(type);
513
514 auto desc = CreateOutputDesc(dyn_cast<abstract::Shape>(shp), type, format);
515 if (desc == nullptr) {
516 MS_LOG(ERROR) << "Update output descriptor failed!";
517 return FAILED;
518 }
519
520 if (IsCustomOp(op)) {
521 if (cus_output_map_->find(op->GetOpType()) == cus_output_map_->end() ||
522 ((*cus_output_map_)[op->GetOpType()].empty())) {
523 MS_LOG(ERROR) << "This op does not create custom output map";
524 return FAILED;
525 }
526 auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
527 MS_EXCEPTION_IF_NULL(cus_op);
528 std::map<int, std::string> output_map = (*cus_output_map_)[op->GetOpType()];
529 (void)cus_op->UpdateOutputDesc(output_map[0], *desc);
530 std::vector<std::vector<int64_t>> out_shapes{desc->GetShape().GetDims()};
531 std::vector<int32_t> out_formats{static_cast<int32_t>(desc->GetFormat())};
532 std::vector<int32_t> out_types{static_cast<int32_t>(desc->GetDataType())};
533 (void)cus_op->SetAttr("output_shapes", out_shapes);
534 (void)cus_op->SetAttr("output_formats", out_formats);
535 (void)cus_op->SetAttr("output_types", out_types);
536 } else {
537 if (!output_map_.empty()) {
538 output_map_.begin()->second.update_out_desc(op, *desc);
539 } else if (!dyn_output_map_.empty()) {
540 dyn_output_map_.begin()->second.update_dyn_output_desc(op, 0, *desc);
541 } else {
542 MS_LOG(DEBUG) << "This op does not have output map";
543 return FAILED;
544 }
545 }
546 return SUCCESS;
547 }
548
GetCustomOpOutputSize(const CusOperatorPtr & cus_op) const549 size_t OpAdapterImpl::GetCustomOpOutputSize(const CusOperatorPtr &cus_op) const {
550 MS_EXCEPTION_IF_NULL(cus_op);
551 if (cus_output_map_->find(cus_op->GetOpType()) == cus_output_map_->end()) {
552 MS_LOG(ERROR) << "This op does not create custom output map";
553 return 0;
554 }
555 size_t output_size = (*cus_output_map_)[cus_op->GetOpType()].size();
556 return output_size;
557 }
558
CreateOutputDesc(const abstract::ShapePtr & shape_ptr,const TypePtr & type,const std::string & format) const559 std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type,
560 const std::string &format) const {
561 if (type == nullptr) {
562 MS_LOG(ERROR) << "Type ptr is nullptr";
563 return nullptr;
564 }
565
566 TypeId me_type = type->type_id();
567 if (kObjectTypeTensorType == me_type) {
568 me_type = dyn_cast<TensorType>(type)->element()->type_id();
569 }
570
571 return TransformUtil::GetGeTensorDesc((shape_ptr == nullptr) ? ShapeVector{} : shape_ptr->shape(), me_type, format);
572 }
573
UpdateMultiOutputDesc(const OperatorPtr & op,const abstract::BaseShapePtr & shp,const TypePtr & type,const std::string & format)574 Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp,
575 const TypePtr &type, const std::string &format) {
576 auto tuple_shp = dyn_cast<abstract::TupleShape>(shp);
577 MS_EXCEPTION_IF_NULL(tuple_shp);
578
579 size_t output_size = 0;
580 bool is_custom_op = IsCustomOp(op);
581 if (is_custom_op) {
582 output_size = GetCustomOpOutputSize(std::dynamic_pointer_cast<CustomOperator>(op));
583 } else {
584 output_size = output_map_.empty()
585 ? static_cast<size_t>(op->GetDynamicOutputNum(dyn_output_map_.begin()->second.name))
586 : output_map_.size();
587 }
588
589 if (output_size == 0) {
590 MS_LOG(DEBUG) << "This op does not have output map";
591 return FAILED;
592 }
593
594 // There are scenarios that output_size is greater than tuple_shape size.
595 // Reserved outputs exist in output_map taking BatchNormGrad as an example.
596 if (output_size < tuple_shp->shape().size()) {
597 MS_LOG(INFO) << "output_map is smaller than tuple_shape size, node: " << op->GetName();
598 return FAILED;
599 }
600
601 std::vector<std::vector<int64_t>> out_shapes;
602 std::vector<int32_t> out_formats;
603 std::vector<int32_t> out_types;
604 for (size_t i = 0; i < tuple_shp->shape().size(); ++i) {
605 auto tuple_type = dyn_cast<Tuple>(type);
606 MS_EXCEPTION_IF_NULL(tuple_type);
607 TypePtr type_elem = tuple_type->elements()[i];
608 if (type_elem == nullptr) {
609 MS_LOG(ERROR) << "Type ptr is nullptr";
610 return FAILED;
611 }
612 TypeId me_type = type_elem->type_id();
613 if (kObjectTypeTensorType == me_type) {
614 me_type = dyn_cast<TensorType>(type_elem)->element()->type_id();
615 }
616 if (me_type == kMetaTypeNone) {
617 continue;
618 }
619
620 auto desc = CreateOutputDesc(dyn_cast<abstract::Shape>(tuple_shp->shape()[i]), type_elem, format);
621 if (desc == nullptr) {
622 MS_LOG(WARNING) << "Create op: " << op->GetName() << " output descriptor failed!";
623 return FAILED;
624 }
625
626 if (is_custom_op) {
627 (void)std::dynamic_pointer_cast<CustomOperator>(op)->UpdateOutputDesc((*cus_output_map_)[op->GetOpType()][i],
628 *desc);
629 out_shapes.push_back(desc->GetShape().GetDims());
630 out_formats.push_back(static_cast<int32_t>(desc->GetFormat()));
631 out_types.push_back(static_cast<int32_t>(desc->GetDataType()));
632 } else {
633 auto it = output_map_.find(i);
634 if (it != output_map_.end()) {
635 it->second.update_out_desc(op, *desc);
636 } else if (!dyn_output_map_.empty()) {
637 dyn_output_map_.begin()->second.update_dyn_output_desc(op, static_cast<unsigned int>(i), *desc);
638 }
639 }
640 }
641 if (is_custom_op) {
642 (void)op->SetAttr("output_shapes", out_shapes);
643 (void)op->SetAttr("output_formats", out_formats);
644 (void)op->SetAttr("output_types", out_types);
645 }
646 return SUCCESS;
647 }
648
CreateNodeDesc(const AnfNodePtr & node,const std::string & format) const649 std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &node, const std::string &format) const {
650 MS_EXCEPTION_IF_NULL(node);
651 TypeId me_type = node->Type()->type_id();
652 if (kObjectTypeTensorType == me_type) {
653 me_type = dyn_cast<TensorType>(node->Type())->element()->type_id();
654 }
655 if (me_type <= kNumberTypeBegin || me_type >= kNumberTypeEnd) {
656 return nullptr;
657 }
658
659 std::vector<int64_t> shape;
660 auto shape_ptr = dyn_cast<abstract::Shape>(node->Shape());
661 if (shape_ptr != nullptr) {
662 shape = shape_ptr->shape();
663 }
664
665 auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, format);
666 if (desc == nullptr) {
667 MS_LOG(ERROR) << "Update output descriptor failed!";
668 return nullptr;
669 }
670 return desc;
671 }
672
UpdateNormalOpInputDesc(const OperatorPtr & op,const AnfNodePtr & node,const std::string format)673 void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node, const std::string format) {
674 if (op == nullptr) {
675 MS_LOG(ERROR) << "op is nullptr";
676 return;
677 }
678 MS_EXCEPTION_IF_NULL(node);
679 std::map<size_t, size_t> real_input_map;
680 if (!dyn_input_map_.empty() && common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, node->cast<CNodePtr>())) {
681 std::vector<int64_t> dyn_input_sizes = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrDynInputSizes);
682 if (!dyn_input_sizes.empty()) {
683 size_t input_index = kIndex1;
684 for (size_t i = 0; i < dyn_input_sizes.size(); ++i) {
685 int64_t dyn_input_size = dyn_input_sizes[i];
686 if (dyn_input_size < 0) {
687 real_input_map[input_index] = i + kIndex1;
688 input_index += 1;
689 } else {
690 input_index += dyn_input_size;
691 }
692 }
693 }
694 }
695
696 auto inputs = node->cast<CNodePtr>()->inputs();
697 for (size_t i = 1; i < inputs.size(); ++i) {
698 size_t real_input_index = i;
699 if (!real_input_map.empty()) {
700 auto iter = real_input_map.find(i);
701 if (iter != real_input_map.end()) {
702 real_input_index = iter->second;
703 } else {
704 continue;
705 }
706 }
707 auto it = input_map_.find(real_input_index);
708 if (it != input_map_.end()) {
709 auto desc = CreateNodeDesc(inputs[i], format);
710 if (desc == nullptr) {
711 continue;
712 }
713
714 it->second.update_input_desc(op, *desc);
715 }
716 }
717 }
718
UpdateCustomOpInputDesc(const CusOperatorPtr & op,const AnfNodePtr & node,const std::string format) const719 void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node,
720 const std::string format) const {
721 if (op == nullptr) {
722 MS_LOG(ERROR) << "op is nullptr";
723 return;
724 }
725 MS_EXCEPTION_IF_NULL(node);
726
727 if (cus_input_map_->find(op->GetOpType()) == cus_input_map_->end() || ((*cus_input_map_)[op->GetOpType()].empty())) {
728 MS_LOG(ERROR) << "This op does not create custom input map";
729 return;
730 }
731
732 mindspore::HashMap<int, std::string> &input_map = (*cus_input_map_)[op->GetOpType()];
733 auto inputs = node->cast<CNodePtr>()->inputs();
734 for (size_t i = 1; i < inputs.size(); ++i) {
735 if (input_map.find(i) != input_map.end()) {
736 auto desc = CreateNodeDesc(inputs[i], format);
737 if (desc == nullptr) {
738 continue;
739 }
740 (void)op->UpdateInputDesc(input_map[i], *desc);
741 }
742 }
743 }
744
updateInputDesc(const OperatorPtr & op,const AnfNodePtr & node)745 void OpAdapterImpl::updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) {
746 MS_EXCEPTION_IF_NULL(op);
747 MS_EXCEPTION_IF_NULL(node);
748 std::string format = GetOpIOFormat(node);
749 if (IsCustomOp(op)) {
750 auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
751 UpdateCustomOpInputDesc(cus_op, node, format);
752 } else {
753 UpdateNormalOpInputDesc(op, node, format);
754 }
755 }
756
updateOutputDesc(const OperatorPtr & op,const abstract::BaseShapePtr & shp,const TypePtr & type,const AnfNodePtr & node)757 void OpAdapterImpl::updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
758 const AnfNodePtr &node) {
759 if (op == nullptr) {
760 MS_LOG(ERROR) << "op is nullptr";
761 return;
762 }
763 MS_EXCEPTION_IF_NULL(node);
764 MS_LOG(DEBUG) << "Op name is " << op->GetName() << " anf is " << node->DebugString() << ", shape: " << shp->ToString()
765 << ", type: " << type;
766 if (AnfUtils::GetOutputTensorNum(node) == 0) {
767 return;
768 }
769
770 auto normal_shape_ptr = dyn_cast<abstract::Shape>(shp);
771 auto no_shape_ptr = dyn_cast<abstract::NoShape>(shp);
772 std::string format = GetOpIOFormat(node);
773
774 if ((normal_shape_ptr != nullptr) || (no_shape_ptr != nullptr)) {
775 if (UpdateSingleOutputDesc(op, shp, type, format) != SUCCESS) {
776 return;
777 }
778 } else if (dyn_cast<abstract::TupleShape>(shp) != nullptr) {
779 if (UpdateMultiOutputDesc(op, shp, type, format) != SUCCESS) {
780 return;
781 }
782 } else {
783 MS_LOG(WARNING) << "Update output desc failed, unknown output shape type";
784 return;
785 }
786 MS_EXCEPTION_IF_NULL(node);
787 if (!node->isa<CNode>()) {
788 return;
789 }
790
791 // Need to update input_desc while the output_desc is updated
792 updateInputDesc(op, node);
793 }
794
setAttr(const OperatorPtr & op,const std::string & attr_key,const ValuePtr & attr_value)795 int OpAdapterImpl::setAttr(const OperatorPtr &op, const std::string &attr_key, const ValuePtr &attr_value) {
796 auto it = attr_map_.find(attr_key);
797 if (it != attr_map_.end()) {
798 // switch case for each avalilable attribute type
799 MS_LOG(DEBUG) << "Op: " << op->GetName() << ", set attr: " << attr_key << "(" << it->second.name
800 << "), value: " << attr_value->ToString();
801 adpt_->AddAttrToDrawGraph(attr_key + std::string("=") + attr_value->ToString());
802 it->second.set_attr(op, attr_value);
803 return 0;
804 }
805 return static_cast<int>(NOT_FOUND);
806 }
807
setAttr(const OperatorPtr & op,const uint32_t & input_idx,const ValuePtr & attr_value)808 int OpAdapterImpl::setAttr(const OperatorPtr &op, const uint32_t &input_idx, const ValuePtr &attr_value) {
809 auto it = input_attr_map_.find(input_idx);
810 if (it != input_attr_map_.end()) {
811 it->second.set_attr(op, attr_value);
812 return static_cast<int>(SUCCESS);
813 }
814 return static_cast<int>(NOT_FOUND);
815 }
816
getAttr(const OperatorPtr & op,const std::string & attr_key,ValuePtr * attr_value)817 int OpAdapterImpl::getAttr(const OperatorPtr &op, const std::string &attr_key, ValuePtr *attr_value) {
818 MS_EXCEPTION_IF_NULL(attr_value);
819 auto it = attr_map_.find(attr_key);
820 if (it != attr_map_.end()) {
821 it->second.get_attr(op, attr_value);
822 return static_cast<int>(SUCCESS);
823 }
824 return static_cast<int>(NOT_FOUND);
825 }
826
getAttr(const OperatorPtr & op,uint32_t input_idx,ValuePtr * attr_value)827 int OpAdapterImpl::getAttr(const OperatorPtr &op, uint32_t input_idx, ValuePtr *attr_value) {
828 MS_EXCEPTION_IF_NULL(attr_value);
829 auto it = input_attr_map_.find(input_idx);
830 if (it != input_attr_map_.end()) {
831 it->second.get_attr(op, attr_value);
832 return static_cast<int>(SUCCESS);
833 }
834 return static_cast<int>(NOT_FOUND);
835 }
836
SetCustomOpAttr(const CusOperatorPtr & op,const PrimitivePtr & prim) const837 int OpAdapterImpl::SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim) const {
838 enum ValueType {
839 SINGLE_VALUE = 0,
840 SEQUEUE_VALUE,
841 UNKNOWN_VALUE,
842 };
843
844 MS_EXCEPTION_IF_NULL(prim);
845 MS_EXCEPTION_IF_NULL(op);
846
847 ValueType value_type = SINGLE_VALUE;
848 static std::unordered_set<std::string> excluded_attr{"IsFeatureMapInputList", "IsFeatureMapOutput"};
849 for (auto item : prim->attrs()) {
850 if (excluded_attr.find(item.first) != excluded_attr.end()) {
851 continue;
852 }
853 if (item.second->isa<Int32Imm>()) {
854 (void)op->SetAttr(item.first, GetValue<int32_t>(item.second));
855 } else if (item.second->isa<Int64Imm>()) {
856 (void)op->SetAttr(item.first, GetValue<int64_t>(item.second));
857 } else if (item.second->isa<StringImm>()) {
858 (void)op->SetAttr(item.first, GetValue<std::string>(item.second));
859 } else if (item.second->isa<BoolImm>()) {
860 (void)op->SetAttr(item.first, GetValue<bool>(item.second));
861 } else if (item.second->isa<FP32Imm>()) {
862 (void)op->SetAttr(item.first, GetValue<float>(item.second));
863 } else if (item.second->isa<ValueSequence>()) {
864 value_type = SEQUEUE_VALUE;
865 auto val_seq = item.second->cast<ValueSequencePtr>();
866 if (val_seq->size() == 0) {
867 std::vector<int64_t> value;
868 (void)op->SetAttr(item.first, value);
869 continue;
870 }
871 if ((*val_seq)[0]->isa<StringImm>()) {
872 (void)op->SetAttr(item.first, GetValue<const std::vector<std::string>>(item.second));
873 } else if ((*val_seq)[0]->isa<FP32Imm>()) {
874 (void)op->SetAttr(item.first, GetValue<const std::vector<float>>(item.second));
875 } else if ((*val_seq)[0]->isa<Int32Imm>()) {
876 (void)op->SetAttr(item.first, GetValue<const std::vector<int32_t>>(item.second));
877 } else if ((*val_seq)[0]->isa<Int64Imm>()) {
878 (void)op->SetAttr(item.first, GetValue<const std::vector<int64_t>>(item.second));
879 } else if ((*val_seq)[0]->isa<BoolImm>()) {
880 (void)op->SetAttr(item.first, GetValue<const std::vector<bool>>(item.second));
881 } else {
882 MS_LOG(EXCEPTION) << "Unsupported custom attribute type in adaptor, prim name: " << prim->name()
883 << ", attr name: " << item.first << ", value: " << item.second->ToString();
884 }
885 } else {
886 MS_LOG(WARNING) << "Unsupported custom attribute type in adaptor, prim name: " << prim->name()
887 << ", attr name: " << item.first << ", value: " << item.second->ToString();
888 return static_cast<int>(NOT_FOUND);
889 }
890
891 if (value_type == SINGLE_VALUE) {
892 adpt_->AddAttrToDrawGraph(item.first + std::string("=") + item.second->ToString());
893 } else if (value_type == SEQUEUE_VALUE) {
894 adpt_->AddAttrToDrawGraph(item.first + std::string("=") + "[...]");
895 }
896 }
897 return 0;
898 }
899
GetOpAttrList(const OperatorPtr & op) const900 std::map<std::string, ValuePtr> OpAdapterImpl::GetOpAttrList(const OperatorPtr &op) const {
901 std::map<std::string, ValuePtr> attr_list;
902 for (auto &it : attr_map_) {
903 ValuePtr value = nullptr;
904 it.second.get_attr(op, &value);
905 (void)attr_list.emplace(it.second.name, value);
906 }
907 for (auto &it : input_attr_map_) {
908 ValuePtr value = nullptr;
909 it.second.get_attr(op, &value);
910 (void)attr_list.emplace(it.second.name, value);
911 }
912 return attr_list;
913 }
914
GetNormalOpAttrList(const OperatorPtr & op,const AnfNodePtr & node) const915 std::map<std::string, ValuePtr> OpAdapterImpl::GetNormalOpAttrList(const OperatorPtr &op,
916 const AnfNodePtr &node) const {
917 MS_EXCEPTION_IF_NULL(node);
918 if (!node->isa<CNode>() || node->cast<CNodePtr>() == nullptr) {
919 return {};
920 }
921 auto cnode = node->cast<CNodePtr>();
922 auto &inputs = cnode->inputs();
923 if (inputs.empty() || !IsValueNode<Primitive>(inputs[0])) {
924 return {};
925 }
926
927 auto prim = GetValueNode<PrimitivePtr>(inputs[0]);
928 std::map<std::string, ValuePtr> attr_list;
929 for (auto &it : attr_map_) {
930 // set attr from extra_attr
931 auto it_extra = extra_attr_->find(it.first);
932 if (it_extra != extra_attr_->end()) {
933 auto value = it_extra->second;
934 (void)attr_list.emplace(it.second.name, value);
935 } else {
936 auto value = prim->GetAttr(it.first);
937 it.second.get_attr(op, &value);
938 (void)attr_list.emplace(it.second.name, value);
939 }
940 }
941
942 auto real_input_indices = GetRealInputIndices(cnode);
943 // set attr from const input
944 for (auto &it : input_attr_map_) {
945 size_t cur_idx = GetRealAnfInputIndex(real_input_indices, it.first);
946 if (inputs.size() <= cur_idx || !inputs[cur_idx]->isa<ValueNode>()) {
947 continue;
948 }
949 auto const_value = GetValueNode(inputs[cur_idx]);
950 MS_LOG(DEBUG) << "Get input attr: input_" << cur_idx << "(" << it.second.name
951 << "), value: " << const_value->ToString();
952 if (const_value->isa<None>()) {
953 continue;
954 }
955 (void)attr_list.emplace(it.second.name, const_value);
956 }
957
958 // Get need convert to input's attr
959 for (auto &it : attr_input_map_) {
960 auto value = prim->GetAttr(it.first);
961 if (value == nullptr) {
962 continue;
963 }
964 (void)attr_list.emplace(it.first, value);
965 }
966 return attr_list;
967 }
968
SetNormalOpAttr(const OperatorPtr & op,const PrimitivePtr & prim)969 int OpAdapterImpl::SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim) {
970 MS_EXCEPTION_IF_NULL(prim);
971 MS_EXCEPTION_IF_NULL(op);
972 for (auto &it : attr_map_) {
973 if (attr_input_map_.count(it.first) != 0) {
974 MS_LOG(WARNING) << "Attr: " << it.first << " will convert to input, please del it from ATTR_MAP.";
975 continue;
976 }
977 auto value = prim->GetAttr(it.first);
978 if (value != nullptr) {
979 // convert parts of attr to str eg. data_format or change ir attr to op attr eg. axis[0]
980 (void)CheckAndConvertUtils::ConvertAttrValueToString(prim->name(), it.first, &value);
981 (void)CheckAndConvertUtils::CheckIrAttrtoOpAttr(prim->name(), it.first, &value);
982 // set attr from primitive
983 int ret = setAttr(op, it.first, value);
984 if (ret != 0) {
985 return ret;
986 }
987 } else {
988 // set attr from extra_attr
989 auto it_extra = extra_attr_->find(it.first);
990 if (it_extra != extra_attr_->end()) {
991 int ret = setAttr(op, it.first, it_extra->second);
992 if (ret != 0) {
993 return ret;
994 }
995 }
996 }
997 }
998 return 0;
999 }
1000
SetNoFoldingOpAttr(const OperatorPtr & op,const PrimitivePtr & prim)1001 int OpAdapterImpl::SetNoFoldingOpAttr(const OperatorPtr &op, const PrimitivePtr &prim) {
1002 MS_EXCEPTION_IF_NULL(prim);
1003 MS_EXCEPTION_IF_NULL(op);
1004 op->SetAttr("no_need_constant_folding", true);
1005 return SetNormalOpAttr(op, prim);
1006 }
1007
setAttr(const OperatorPtr & op,const PrimitivePtr & prim)1008 int OpAdapterImpl::setAttr(const OperatorPtr &op, const PrimitivePtr &prim) {
1009 int ret = 0;
1010 if (IsCustomPrim(prim)) {
1011 auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
1012 ret = SetCustomOpAttr(cus_op, prim);
1013 } else if (IsNoNeedConstantFoldCNode(prim)) {
1014 ret = SetNoFoldingOpAttr(op, prim);
1015 } else {
1016 ret = SetNormalOpAttr(op, prim);
1017 }
1018 return ret;
1019 }
1020
setAttr(const OperatorPtr & op,const AnfNodePtr & node)1021 int OpAdapterImpl::setAttr(const OperatorPtr &op, const AnfNodePtr &node) {
1022 // no attribute for lonely node
1023 MS_EXCEPTION_IF_NULL(node);
1024 if (!node->isa<CNode>() || node->cast<CNodePtr>() == nullptr) {
1025 return 0;
1026 }
1027
1028 auto cnode = node->cast<CNodePtr>();
1029 auto &inputs = cnode->inputs();
1030 if (inputs.empty()) {
1031 return 0;
1032 }
1033
1034 // get Attr T from abstract of anfnode first,
1035 // if attr "T" appears in primitive, the primitive T will cover this one
1036 if (attr_map_.find("T") != attr_map_.end()) {
1037 // get dtype from inputs[1], if the node has no inputs, set the attr T with output dtype
1038 TypePtr type;
1039 if (inputs.size() > 1) {
1040 type = inputs[1]->Type();
1041 } else {
1042 type = node->Type();
1043 }
1044 if (type != nullptr) {
1045 (void)setAttr(op, "T", MakeValue(type));
1046 }
1047 }
1048
1049 // set attr from primitive and ExtraAttr
1050 if (IsValueNode<Primitive>(inputs[0])) {
1051 // set attr from primitive
1052 PrimitivePtr prim = GetValueNode<PrimitivePtr>(inputs[0]);
1053 int ret = setAttr(op, prim);
1054 if (ret != 0) {
1055 return ret;
1056 }
1057 }
1058
1059 auto real_input_indices = GetRealInputIndices(cnode);
1060 // set attr from const input
1061 for (auto &it : input_attr_map_) {
1062 size_t cur_idx = GetRealAnfInputIndex(real_input_indices, it.first);
1063 if (inputs.size() <= cur_idx || !inputs[cur_idx]->isa<ValueNode>()) {
1064 continue;
1065 }
1066
1067 auto const_value = GetValueNode(inputs[cur_idx]);
1068 MS_LOG(INFO) << "Set attr: input_" << cur_idx << "(" << it.second.name << "), value: " << const_value->ToString();
1069 if (const_value->isa<None>()) {
1070 continue;
1071 }
1072 if (const_value->isa<mindspore::tensor::Tensor>()) {
1073 auto tensorptr = const_value->cast<mindspore::tensor::TensorPtr>();
1074 const_value = CreateValueFromTensor(tensorptr);
1075 }
1076
1077 adpt_->AddAttrToDrawGraph(it.second.name + std::string("=") + const_value->ToString());
1078 it.second.set_attr(op, const_value);
1079 }
1080 return 0;
1081 }
1082 } // namespace transform
1083 } // namespace mindspore
1084