• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/device/ascend/kernel_select_ascend.h"
18 #include "backend/session/anf_runtime_algorithm.h"
19 #include "runtime/device/kernel_info.h"
20 #include "ir/func_graph.h"
21 #include "backend/kernel_compiler/common_utils.h"
22 #include "backend/kernel_compiler/kernel_query.h"
23 #include "backend/kernel_compiler/kernel_build_info.h"
24 
25 namespace mindspore {
26 namespace device {
27 namespace ascend {
28 namespace {
29 // sort format according the number of occurrences.
cmp_format_num(const std::pair<std::string,size_t> & a,const std::pair<std::string,size_t> & b)30 bool cmp_format_num(const std::pair<std::string, size_t> &a, const std::pair<std::string, size_t> &b) {
31   if (a.second != b.second) {
32     return a.second > b.second;
33   } else if (a.first == kOpFormat_DEFAULT) {
34     return a.second + 1 > b.second;
35   } else if (b.first == kOpFormat_DEFAULT) {
36     return a.second > b.second + 1;
37   }
38   return a.second > b.second;
39 }
40 
GetPrimitivePrecision(const CNodePtr & cnode)41 TypeId GetPrimitivePrecision(const CNodePtr &cnode) {
42   auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
43   MS_EXCEPTION_IF_NULL(primitive);
44 
45   TypeId except_type = kTypeUnknown;
46   if (primitive->GetAttr(kAttrFixPrecision) != nullptr) {
47     auto strExceptDtype = GetValue<std::string>(primitive->GetAttr(kAttrFixPrecision));
48     if (strExceptDtype == "float16") {
49       except_type = kNumberTypeFloat16;
50     } else if (strExceptDtype == "float32") {
51       except_type = kNumberTypeFloat32;
52     } else {
53       MS_LOG(EXCEPTION) << "The fix precision must be float16 or float32, but got" << strExceptDtype;
54     }
55   }
56 
57   return except_type;
58 }
59 }  // namespace
60 
ResetKernelBuildInfo(const CNodePtr & kernel_node)61 void ResetKernelBuildInfo(const CNodePtr &kernel_node) {
62   size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
63   for (size_t input_index = 0; input_index < input_num; ++input_index) {
64     auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
65     MS_EXCEPTION_IF_NULL(input_kernel_node);
66     auto kernel_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
67     if (!kernel::IsWeightBoundary(kernel_with_index.first)) {
68       continue;
69     }
70     // reset format and dtype.
71     kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
72     builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
73     builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown});
74     AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_kernel_node.get());
75   }
76 }
77 
UpdateKernelInfo(const std::vector<AnfNodePtr> & node_list)78 void UpdateKernelInfo(const std::vector<AnfNodePtr> &node_list) {
79   for (size_t i = 0; i < node_list.size(); ++i) {
80     // select nodes in subgraph.
81     auto anf_node = node_list[i];
82     MS_EXCEPTION_IF_NULL(anf_node);
83     auto cnode = anf_node->cast<CNodePtr>();
84     MS_EXCEPTION_IF_NULL(cnode);
85     auto fix_precision_type = GetPrimitivePrecision(cnode);
86     if (fix_precision_type != kTypeUnknown) {
87       std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
88       kernel::KernelQuery(cnode, &kernel_info_list, KernelType::AKG_KERNEL);
89 
90       for (size_t index = 0; index < kernel_info_list.size(); ++index)
91         // only math the first input
92         if (kernel_info_list[index]->GetInputDeviceType(0) == fix_precision_type &&
93             kernel_info_list[index]->GetInputFormat(0) == AnfAlgo::GetPrevNodeOutputFormat(cnode, 0) &&
94             AnfAlgo::GetInputDeviceDataType(cnode, 0) != fix_precision_type) {
95           auto selected_kernel_info_ptr = kernel_info_list[index];
96           ResetKernelBuildInfo(cnode);
97           AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, cnode.get());
98           SetTensorDeviceInfo(cnode);
99           break;
100         }
101     }
102   }
103 }
104 
CanConvertDefaultShapeToNZ(const std::vector<size_t> & shape)105 bool CanConvertDefaultShapeToNZ(const std::vector<size_t> &shape) {
106   for (size_t i = 1; i <= shape.size(); ++i) {
107     if (i > 2) {
108       break;
109     }
110     if (shape[shape.size() - i] != 1 && shape[shape.size() - i] % kCubeSize != 0) {
111       return false;
112     }
113   }
114   return true;
115 }
116 
DefaultToFracNZAxis(const std::vector<size_t> & ori_shape,const std::vector<int64_t> & axis)117 std::vector<int64_t> DefaultToFracNZAxis(const std::vector<size_t> &ori_shape, const std::vector<int64_t> &axis) {
118   std::vector<int64_t> frac_nz_axis = axis;
119   auto shape_len = SizeToLong(ori_shape.size());
120   for (size_t i = 0; i < axis.size(); ++i) {
121     auto axis_idx = (frac_nz_axis[i] + shape_len) % shape_len;
122     if (axis_idx == shape_len - SizeToLong(kIndex1)) {
123       frac_nz_axis[i] = axis_idx - SizeToLong(kIndex1);
124       frac_nz_axis.push_back(axis_idx + SizeToLong(kIndex2));
125     } else if (axis_idx == shape_len - SizeToLong(kIndex2)) {
126       frac_nz_axis[i] = axis_idx + SizeToLong(kIndex1);
127       frac_nz_axis.push_back(axis_idx + SizeToLong(kIndex2));
128     } else {
129       frac_nz_axis[i] = axis_idx;
130     }
131   }
132   return frac_nz_axis;
133 }
134 
GetReducedFracNZShape(const std::vector<size_t> & ori_shape,const std::vector<int64_t> & axis,bool keep_dims)135 std::vector<size_t> GetReducedFracNZShape(const std::vector<size_t> &ori_shape, const std::vector<int64_t> &axis,
136                                           bool keep_dims) {
137   std::vector<size_t> result;
138   std::set<size_t> positive_idx;
139   for (const auto &a : axis) {
140     positive_idx.insert(a >= 0 ? LongToSize(a) : ori_shape.size() + LongToSize(a));
141   }
142   for (size_t i = 0; i < ori_shape.size(); ++i) {
143     if (positive_idx.count(i) == 0) {
144       result.push_back(ori_shape[i]);
145     } else if (keep_dims) {
146       result.push_back(1);
147     }
148   }
149   return result;
150 }
151 
UpdateFracNZReduceOp(const CNodePtr & cnode)152 void UpdateFracNZReduceOp(const CNodePtr &cnode) {
153   MS_EXCEPTION_IF_NULL(cnode);
154   auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0);
155   if (input_format == kOpFormat_FRAC_NZ) {
156     // Clone primitive to modify it
157     auto prim = GetCNodePrimitive(cnode);
158     auto new_prim = std::make_shared<Primitive>(*prim);
159     auto new_prim_node = NewValueNode(new_prim);
160     cnode->set_input(0, new_prim_node);
161 
162     auto axis_value = new_prim->GetAttr(kAttrAxis);
163     std::vector<int64_t> default_axis;
164     if (axis_value->isa<ValueList>()) {
165       auto value_list = dyn_cast<ValueList>(axis_value);
166       for (const auto &item : value_list->value()) {
167         if (item->isa<Int64Imm>()) {
168           default_axis.push_back(GetValue<int64_t>(item));
169         } else {
170           MS_LOG(EXCEPTION) << "GetValue type should be int64";
171         }
172       }
173     } else if (axis_value->isa<ValueTuple>()) {
174       auto value_tuple = dyn_cast<ValueTuple>(axis_value);
175       for (const auto &item : value_tuple->value()) {
176         if (item->isa<Int64Imm>()) {
177           default_axis.push_back(GetValue<int64_t>(item));
178         } else {
179           MS_LOG(EXCEPTION) << "GetValue type should be int64";
180         }
181       }
182     } else {
183       MS_LOG(ERROR) << "Axis attr type is not correct!";
184     }
185     auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
186     std::vector<int64_t> frac_nz_axis = DefaultToFracNZAxis(infer_shape, default_axis);
187     AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<std::vector<int64_t>>(frac_nz_axis), cnode);
188     auto output_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
189     if (output_shape.size() == 1) {
190       AnfAlgo::SetNodeAttr(kAttrOutputDefault, MakeValue<bool>(true), cnode);
191     }
192   }
193 }
194 
// Chooses a default format for the graph kernel's inputs by majority vote
// over the formats of non-parameter inputs and already-selected parameters
// (ties prefer kOpFormat_DEFAULT — see cmp_format_num). Sets *use_same_format
// to false when some weight parameter has no device dtype selected yet, and
// downgrades FRAC_NZ to kOpFormat_DEFAULT when a low-rank weight cannot hold it.
void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, bool *use_same_format) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(default_format);
  MS_EXCEPTION_IF_NULL(use_same_format);
  // First pass: count occurrences of each candidate format.
  std::unordered_map<std::string, size_t> all_input_formats;
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  for (size_t i = 0; i < input_num; ++i) {
    auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
    MS_EXCEPTION_IF_NULL(input_kernel_node);
    if (!input_kernel_node->isa<Parameter>()) {
      // Non-parameter input: vote with the producer's output format.
      ++all_input_formats[AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i)];
      continue;
    }
    auto para = input_kernel_node->cast<ParameterPtr>();
    if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) {
      // Parameter whose kernel info is already selected: vote with its format.
      ++all_input_formats[AnfAlgo::GetOutputFormat(para, 0)];
      continue;
    }
    // Unselected weight parameter: its format is still free to choose.
    *use_same_format = false;
  }

  if (all_input_formats.empty()) {
    // all inputs are parameter.
    *default_format = kOpFormat_NC1HWC0;
  } else {
    // Pick the most frequent format.
    std::vector<std::pair<std::string, size_t>> pairs;
    for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) {
      pairs.emplace_back(std::make_pair(iter->first, iter->second));
    }

    std::sort(pairs.begin(), pairs.end(), cmp_format_num);
    *default_format = pairs.begin()->first;
  }

  // Second pass: FRAC_NZ needs at least a 2-D shape; fall back to the default
  // format when some still-unselected weight is lower-rank.
  for (size_t i = 0; i < input_num; ++i) {
    auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
    MS_EXCEPTION_IF_NULL(input_kernel_node);
    if (!input_kernel_node->isa<Parameter>() ||
        AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) != kTypeUnknown) {
      continue;
    }
    auto weight_infer_shape = AnfAlgo::GetOutputInferShape(input_kernel_node, 0);
    if (weight_infer_shape.size() < kShape2dDims && *default_format == kOpFormat_FRAC_NZ) {
      *default_format = kOpFormat_DEFAULT;
      *use_same_format = true;
      break;
    }
  }
}
244 
// Decides the format and device dtype of every subgraph input and writes the
// chosen build info onto the corresponding input_list parameter node.
// graph_input_format / graph_input_type receive one entry per input, in order.
void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list,
                            const std::string &default_format, bool use_same_format,
                            std::vector<std::string> *graph_input_format, std::vector<TypeId> *graph_input_type) {
  MS_EXCEPTION_IF_NULL(graph_input_format);
  MS_EXCEPTION_IF_NULL(graph_input_type);
  // We set same format to all inputs of graph kernel subgraph, and process this latter.
  // We set dtype to inputs of graph kernel subgraph same as infer dtypes.
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  for (size_t i = 0; i < input_num; ++i) {
    auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
    MS_EXCEPTION_IF_NULL(input_kernel_node);
    if (use_same_format) {
      bool can_convert = true;
      if (default_format == kOpFormat_FRAC_NZ) {
        // FRAC_NZ needs tileable trailing dims; fall back to default otherwise.
        auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
        if (!CanConvertDefaultShapeToNZ(infer_shape)) {
          MS_LOG(WARNING) << "Shape can't be converted to frac nz shape, so use default format instead";
          can_convert = false;
        }
      }
      if (can_convert) {
        graph_input_format->emplace_back(default_format);
      } else {
        graph_input_format->emplace_back(kOpFormat_DEFAULT);
      }
      graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
      continue;
    }

    if (!input_kernel_node->isa<Parameter>()) {
      // subgraph parameter from output of other nodes.
      graph_input_format->push_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i));
      graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
      continue;
    }

    auto para = input_kernel_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(para);
    if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) {
      // parameter already selected.
      graph_input_format->push_back(AnfAlgo::GetOutputFormat(para, 0));
      graph_input_type->push_back(AnfAlgo::GetOutputDeviceDataType(para, 0));
      continue;
    }

    // weight parameter: use the voted default format and the inferred dtype.
    graph_input_format->push_back(default_format);
    graph_input_type->push_back(AnfAlgo::GetOutputInferDataType(input_kernel_node, 0));
  }

  // Publish the decisions onto each subgraph input node's kernel build info.
  for (size_t i = 0; i < input_num; ++i) {
    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
    std::vector<std::string> outputs_format = {(*graph_input_format)[i]};
    std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]};
    builder.SetOutputsFormat(outputs_format);
    builder.SetOutputsDeviceType(outputs_device_type);
    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
  }
}
304 
// Selects kernel info for every node of the subgraph, then post-processes
// ReduceSum nodes: FRAC_NZ reductions get their axis attribute rewritten, and
// those flagged kAttrOutputDefault with a non-default output format are
// rerouted through a newly inserted EquivFormat node.
void UpdateEquivFormat(const std::vector<AnfNodePtr> &node_list, const FuncGraphPtr &func_graph,
                       const FuncGraphManagerPtr &mng) {
  MS_EXCEPTION_IF_NULL(mng);
  for (size_t i = 0; i < node_list.size(); ++i) {
    // select nodes in subgraph.
    auto anf_node = node_list[i];
    MS_EXCEPTION_IF_NULL(anf_node);
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
    SelectKernelInfo(cnode, KernelType::AKG_KERNEL);
    // Update ReduceSum
    if (!IsPrimitiveCNode(cnode, prim::kPrimReduceSum)) {
      continue;
    }
    UpdateFracNZReduceOp(cnode);
    // If ReduceSum's output is 1d and not Default format, convert it to Default format
    auto out_format = AnfAlgo::GetOutputFormat(cnode, 0);
    if (out_format == kOpFormat_DEFAULT || !AnfAlgo::HasNodeAttr(kAttrOutputDefault, cnode)) {
      continue;
    }
    // Insert EquivFormat node, then select kernel info again
    std::vector<AnfNodePtr> trans_inputs;
    trans_inputs.push_back(NewValueNode(prim::kPrimEquivFormat));
    trans_inputs.push_back(cnode);
    CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
    // NOTE(review): the infer dtype is taken from the reduce's *input* while
    // the shape comes from its *output* — presumably intentional; confirm.
    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0)},
                                        {AnfAlgo::GetOutputInferShape(cnode, 0)}, trans_node.get());
    AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue<std::vector<std::string>>({"x"}), trans_node);

    if (trans_node->kernel_info() == nullptr) {
      trans_node->set_kernel_info(std::make_shared<device::KernelInfo>());
    }
    SelectKernelInfo(trans_node, KernelType::AKG_KERNEL);
    // Redirect all users of the reduce node to the EquivFormat node.
    mng->Replace(cnode, trans_node);
  }
}
342 
// Cross-checks each subgraph input's chosen format/dtype against what the
// inner consumer kernels actually selected. On a format mismatch the input
// falls back to default_format; on a dtype mismatch (for non-Parameter
// inputs) it falls back to the inferred dtype. Mismatched inputs are flagged
// in need_update for UpdateFormatsAndDtypes to apply.
void CheckFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list,
                           const FuncGraphManagerPtr &mng, const std::string &default_format,
                           std::vector<std::string> *graph_input_format, std::vector<TypeId> *graph_input_type,
                           std::vector<bool> *need_update) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(mng);
  MS_EXCEPTION_IF_NULL(graph_input_format);
  MS_EXCEPTION_IF_NULL(graph_input_type);
  MS_EXCEPTION_IF_NULL(need_update);
  // check graph input format and dtype use inner ops.
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (graph_input_format->size() != input_num || graph_input_type->size() != input_num ||
      need_update->size() != input_num) {
    MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString()
                      << "], [" << graph_input_format->size() << "] != [" << input_num << "]";
  }
  auto &node_users = mng->node_users();
  for (size_t i = 0; i < input_num; ++i) {
    auto &input = input_list[i];
    auto iter = node_users.find(input);
    if (iter == node_users.end() || iter->second.empty()) {
      // Input is not consumed by any inner node.
      continue;
    }
    for (auto &node_user : iter->second) {
      if (node_user.first->kernel_info() == nullptr || !node_user.first->kernel_info()->has_build_info()) {
        // maybe not a real kernel.
        continue;
      }
      // node_user.second is 1-based (cnode input 0 is the primitive).
      auto user_format = AnfAlgo::GetInputFormat(node_user.first, IntToSize(node_user.second - 1));
      if (user_format != (*graph_input_format)[i]) {
        MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString() << " of ["
                        << kernel_node->DebugString()
                        << "] selected different format. we use default: " << default_format;
        (*graph_input_format)[i] = default_format;
        (*need_update)[i] = true;
      }

      // Parameters keep their dtype; other inputs must match the user's
      // selected input dtype.
      if (kernel_node->input(i + 1)->isa<Parameter>() ||
          AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1)) == (*graph_input_type)[i]) {
        continue;
      }

      TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0);
      MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString() << " of ["
                      << kernel_node->DebugString()
                      << "] selected different dtype. we use default: " << TypeIdLabel(default_dtype);
      (*graph_input_type)[i] = default_dtype;
      (*need_update)[i] = true;
    }
  }
}
394 
UpdateFormatsAndDtypes(const CNodePtr & kernel_node,const std::vector<AnfNodePtr> & node_list,const std::vector<AnfNodePtr> & input_list,const std::vector<bool> & need_update,const std::vector<std::string> & graph_input_format,const std::vector<TypeId> & graph_input_type)395 void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &node_list,
396                             const std::vector<AnfNodePtr> &input_list, const std::vector<bool> &need_update,
397                             const std::vector<std::string> &graph_input_format,
398                             const std::vector<TypeId> &graph_input_type) {
399   MS_EXCEPTION_IF_NULL(kernel_node);
400   // update graph input format and dtype use inner ops.
401   size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
402   if (graph_input_format.size() != input_num || graph_input_type.size() != input_num ||
403       need_update.size() != input_num) {
404     MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString()
405                       << "], [" << graph_input_format.size() << "] != [" << input_num << "]";
406   }
407   for (size_t i = 0; i < input_num; ++i) {
408     if (!need_update[i]) {
409       continue;
410     }
411 
412     MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString()
413                   << "] to: " << graph_input_format[i];
414     MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString()
415                   << "] to: " << TypeIdLabel(graph_input_type[i]);
416     kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
417     std::vector<std::string> outputs_format = {graph_input_format[i]};
418     std::vector<TypeId> outputs_device_type = {graph_input_type[i]};
419     builder.SetOutputsFormat(outputs_format);
420     builder.SetOutputsDeviceType(outputs_device_type);
421     AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
422   }
423 
424   ResetKernelBuildInfo(kernel_node);
425   // select nodes in subgraph again.
426   for (size_t i = 0; i < node_list.size(); ++i) {
427     auto anf_node = node_list[i];
428     MS_EXCEPTION_IF_NULL(anf_node);
429     auto cnode = anf_node->cast<CNodePtr>();
430     MS_EXCEPTION_IF_NULL(cnode);
431     kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
432     size_t cnode_input_num = AnfAlgo::GetInputTensorNum(cnode);
433     for (size_t j = 0; j < cnode_input_num; ++j) {
434       auto input_node = cnode->input(j + 1);
435       MS_EXCEPTION_IF_NULL(input_node);
436       if (!IsValueNode<tensor::Tensor>(input_node)) {
437         continue;
438       }
439       // reset format and dtype of const tensor.
440       builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
441       builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown});
442       AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_node.get());
443     }
444     SelectKernelInfo(node_list[i]->cast<CNodePtr>(), KernelType::AKG_KERNEL);
445   }
446 }
447 
SetGraphKernelInfo(const CNodePtr & kernel_node,const std::vector<std::pair<AnfNodePtr,size_t>> & output_index,const std::vector<std::string> & graph_input_format,const std::vector<TypeId> & graph_input_type)448 void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector<std::pair<AnfNodePtr, size_t>> &output_index,
449                         const std::vector<std::string> &graph_input_format,
450                         const std::vector<TypeId> &graph_input_type) {
451   MS_EXCEPTION_IF_NULL(kernel_node);
452   std::vector<std::string> graph_output_format;
453   std::vector<TypeId> graph_output_type;
454   for (size_t i = 0; i < output_index.size(); ++i) {
455     auto const &output = output_index[i];
456     graph_output_format.push_back(AnfAlgo::GetOutputFormat(output.first, output.second));
457     TypeId output_type(kTypeUnknown);
458     if (output.first->isa<CNode>()) {
459       output_type = AnfAlgo::GetCNodeOutputPrecision(output.first);
460     }
461     if (output_type == kTypeUnknown) {
462       output_type = AnfAlgo::GetOutputDeviceDataType(output.first, output.second);
463     }
464     graph_output_type.push_back(output_type);
465   }
466 
467   kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
468   graph_info_builder.SetInputsFormat(graph_input_format);
469   graph_info_builder.SetInputsDeviceType(graph_input_type);
470   graph_info_builder.SetOutputsFormat(graph_output_format);
471   graph_info_builder.SetOutputsDeviceType(graph_output_type);
472   graph_info_builder.SetProcessor(kernel::Processor::AICORE);
473   graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
474   graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
475   auto graph_selected_info = graph_info_builder.Build();
476   MS_EXCEPTION_IF_NULL(graph_selected_info);
477   AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());
478   SetTensorDeviceInfo(kernel_node);
479 }
480 
SelectGraphKernelInfo(const CNodePtr & kernel_node,const FuncGraphPtr & func_graph)481 void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) {
482   MS_EXCEPTION_IF_NULL(kernel_node);
483   MS_EXCEPTION_IF_NULL(func_graph);
484 
485   // collect input info of funcgraph
486   std::vector<AnfNodePtr> node_list;
487   std::vector<AnfNodePtr> input_list;
488   std::vector<AnfNodePtr> output_list;
489   kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
490   if (input_list.size() != kernel_node->inputs().size() - 1) {
491     MS_EXCEPTION(ArgumentError) << "Input num of funcgraph[" << func_graph->ToString() << "] not equal input of cnode["
492                                 << kernel_node->DebugString() << "], [%" << input_list.size() << "] != ["
493                                 << kernel_node->inputs().size() << "]";
494   }
495 
496   std::string default_format;
497   bool use_same_format = true;
498   GetDefaultFormat(kernel_node, &default_format, &use_same_format);
499   MS_LOG(DEBUG) << "GraphKernel[" << func_graph->ToString() << "] use same input format[" << default_format
500                 << "] for ParameterWeight.";
501 
502   std::vector<std::string> graph_input_format;
503   std::vector<TypeId> graph_input_type;
504   UpdateInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format,
505                          &graph_input_type);
506 
507   auto mng = func_graph->manager();
508   if (mng == nullptr) {
509     mng = Manage(func_graph, true);
510   }
511   UpdateEquivFormat(node_list, func_graph, mng);
512   node_list.clear();
513   input_list.clear();
514   output_list.clear();
515   kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
516 
517   // update graph input format and dtype use inner ops.
518   std::vector<bool> need_update(AnfAlgo::GetInputTensorNum(kernel_node), false);
519   CheckFormatsAndDtypes(kernel_node, input_list, mng, default_format, &graph_input_format, &graph_input_type,
520                         &need_update);
521   UpdateFormatsAndDtypes(kernel_node, node_list, input_list, need_update, graph_input_format, graph_input_type);
522 
523   // set fix_precision for kernel when the me prim has fix_precision attr
524   UpdateKernelInfo(node_list);
525 
526   auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
527   SetGraphKernelInfo(kernel_node, output_index, graph_input_format, graph_input_type);
528 }
529 }  // namespace ascend
530 }  // namespace device
531 }  // namespace mindspore
532