/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/ascend/kernel_select_ascend.h"

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/kernel_compiler/kernel_query.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "common/trans.h"
#include "debug/anf_ir_dump.h"
#include "frontend/operator/ops.h"
#include "utils/ms_context.h"
#include "utils/ms_utils.h"
#include "utils/trace_base.h"
namespace mindspore {
namespace device {
namespace ascend {
namespace {
const int kWeightUnInitScore = 1;
const int kWeightInitScore = 2;
const int kFeatureMapBaseScore = 10;
constexpr auto kPriChoosenFormat = "pri_format";
enum MatchCountPriority : int {
  MATCH_COUNT_PRIORITY_BEGIN = 0,
  MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
  MATCH_FORMAT_COUNT,
  MATCH_SPECIAL_FORMAT_COUNT,
  MATCH_DEFAULT_FORMAT_COUNT,
  MATCH_OUTPUT_DTYPE_COUNT,
  MATCH_COUNT_PRIORITY_END
};
const std::map<std::string, std::vector<std::string>> kNextOpFormatList = {
  {prim::kPrimConv2D->name(), {kOpFormat_NC1HWC0, kOpFormat_FRAC_Z}}};

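// Check whether every input and output device type in the kernel build info matches the data type inferred for the
// node; only an exact match on all inputs and outputs returns true.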
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
  MS_EXCEPTION_IF_NULL(cnode);
  // Check input data type
  for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
    TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
    if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {
      return false;
    }
  }
  // Check output data type
  for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
    if (kernel_build_info.GetOutputDeviceType(output_index) != AnfAlgo::GetOutputInferDataType(cnode, output_index)) {
      return false;
    }
  }
  return true;
}

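// Derive the preferred ("priority") format for the node from the formats and ranks of its inputs: start from
// NC1HWC0, take the first special format seen on a feature-map input, and fall back to the default format (or
// NDC1HWC0 for 5-D inputs) when the inputs disagree or are not 4-D. The choice is recorded in the node's pri_format
// attribute.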
string GetPriorityMatchFormat(const CNodePtr &cnode) {
  constexpr size_t k5dSize = 5;
  constexpr size_t k4dSize = 4;
  string priority_matched_format = kOpFormat_NC1HWC0;
  bool is_init = false;
  bool need_change_nd = false;
  bool is_5d_input = false;
  size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
  for (size_t index = 0; index < input_num; ++index) {
    auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
    if (AnfAlgo::IsFeatureMapInput(cnode, index) &&
        kHWSpecialFormatSet.find(pre_output_format) != kHWSpecialFormatSet.end()) {
      priority_matched_format = !is_init ? pre_output_format : priority_matched_format;
      is_init = true;
    }
    // The feature-map inputs carry two or more different special formats.
    if (priority_matched_format != pre_output_format && pre_output_format != kOpFormat_DEFAULT) {
      priority_matched_format = kOpFormat_DEFAULT;
    }
    auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size();
    if (input_shape_size == k5dSize) {
      is_5d_input = true;
    }
    need_change_nd = (need_change_nd || (input_shape_size != k4dSize && input_shape_size > 1));
  }
  if (need_change_nd && priority_matched_format != kOpFormat_FRAC_NZ) {
    priority_matched_format = kOpFormat_DEFAULT;
  }
  if (is_5d_input && priority_matched_format != kOpFormat_FRAC_NZ) {
    priority_matched_format = kOpFormat_NDC1HWC0;
  }
  AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode);
  return priority_matched_format;
}

/**
 * Compare two count vectors by priority and choose the better one, the way two numbers are compared: the
 * highest-priority position is compared first; only if it is equal does the comparison move to the next position.
 * Example: [3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3]
 */
bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best_item) {
  MS_EXCEPTION_IF_NULL(best_item);
  if (cur_item.size() != best_item->size()) {
    MS_LOG(ERROR) << "Item size should be same!";
    return false;
  }
  // Update best_item by comparing cur_item with the current best_item
  for (size_t i = 0; i < cur_item.size(); i++) {
    if (cur_item[i] > best_item->at(i)) {
      *best_item = cur_item;
      return true;
    } else if (cur_item[i] == best_item->at(i)) {
      continue;
    } else {
      return false;
    }
  }
  return false;
}

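// Score how well a candidate kernel build info matches the node. Input matches are weighted by whether the input is
// a feature map or an (initialized) weight; format and dtype matches are accumulated into the slots of
// cur_kernelinfo_match_counts indexed by MatchCountPriority.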
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
                          std::vector<int> *const cur_kernelinfo_match_counts) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts);
  if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) {
    MS_LOG(EXCEPTION) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END;
  }
  auto pri_match_format = GetPriorityMatchFormat(kernel_node);
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  for (size_t input_index = 0; input_index < input_num; ++input_index) {
    auto input_anf_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, input_index), 0).first;
    MS_EXCEPTION_IF_NULL(input_anf_node);
    // We do not take ValueNode into consideration in graph kernel.
    auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWeightInitScore;
    if (AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) {
      base_score = kWeightUnInitScore;
    }
    if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
      (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score;
    }
    // Match the output's fixed precision first, if one is set.
    auto prev_device_type = AnfAlgo::GetPrevNodeOutputPrecision(kernel_node, input_index);
    if (prev_device_type == kTypeUnknown) {
      prev_device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
    }
    if (kernel_build_info.GetInputDeviceType(input_index) == prev_device_type) {
      (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score;
    }
    if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
      (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score;
    }
    if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT ||
        kernel_build_info.GetInputFormat(input_index) == kOpFormat_NCDHW) {
      (*cur_kernelinfo_match_counts)[MATCH_DEFAULT_FORMAT_COUNT] += base_score;
    }
  }

  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  for (size_t output_index = 0; output_index < output_num; ++output_index) {
    // Count the outputs whose dtype matches between the node's abstract and the kernel info.
    if (kernel_build_info.GetOutputDeviceType(output_index) ==
        AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) {
      (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT] += 1;
    }
    if (kernel_build_info.GetOutputFormat(output_index) == pri_match_format) {
      (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += 1;
    }
  }
}

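// Compose a log message describing that the node's kernel was selected by raising or reducing precision, including
// the node's inferred input/output types and the selected kernel build info.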
std::string PrintRaiseOrReducePrecisionSelectedInfo(
  const CNodePtr &cnode, const std::shared_ptr<kernel::KernelBuildInfo> &selected_kernel_build_info,
  bool precision_reduce) {
  MS_EXCEPTION_IF_NULL(selected_kernel_build_info);
  MS_EXCEPTION_IF_NULL(cnode);
  std::ostringstream buffer;
  buffer << cnode->DebugString();
  if (precision_reduce) {
    buffer << " Reduce precision, node datatype: \n";
  } else {
    buffer << " Raise precision, node datatype: \n";
  }
  PrintInputAndOutputInferType(buffer, cnode);
  buffer << ", select kernel:" << selected_kernel_build_info->ToString();
  return buffer.str();
}

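// Pick the candidate kernel build info with the best match counts for the node; candidates are compared with
// PriorityChooseItem over the MatchCountPriority slots.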
std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo(
  const CNodePtr &kernel_node, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
  if (kernel_info_list.empty()) {
    return nullptr;
  }
  std::vector<int> most_match_counts = {-1, -1, -1, -1, -1};
  size_t selected_index = 0;
  for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
    std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0};
    auto kernel_info_ptr = kernel_info_list[info_index];
    MS_EXCEPTION_IF_NULL(kernel_info_ptr);
    UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts);
    // Keep the kernel info whose match counts compare highest in MatchCountPriority order.
    if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) {
      selected_index = info_index;
    }
  }
  return kernel_info_list[selected_index];
}

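// Keep only the kernel build infos whose input and output device types exactly match the node's inferred data types.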
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype(
  const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> result;
  for (const auto &kernel_build_info : kernel_info_list) {
    MS_EXCEPTION_IF_NULL(kernel_build_info);
    if (!MatchInferOutputDataType(cnode, *kernel_build_info)) {
      continue;
    }
    result.push_back(kernel_build_info);
  }
  return result;
}

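// Check whether in_dtype either already equals device_dtype or can be converted to it through type_map (raise or
// reduce precision). *flag is set when an int64 -> int32 reduction happens so the caller can log a warning.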
bool CheckHitTargetDtype(const std::map<TypeId, TypeId> &type_map, const TypeId &in_dtype, const TypeId &device_dtype,
                         bool *flag) {
  auto iter = type_map.find(in_dtype);
  // If the inferred dtype is not in type_map and does not equal the kernel info dtype, return false.
  if (iter == type_map.end() && in_dtype != device_dtype) {
    return false;
  }
  // If the inferred dtype is in type_map but its raise/reduce target does not match the kernel info dtype, and the
  // inferred dtype itself does not equal the kernel info dtype either, return false.
  if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) {
    return false;
  }
  if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) {
    *flag = true;
  }
  return true;
}

bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode,
                    const std::map<TypeId, TypeId> &type_map) {
  // Filter out kernel infos whose datatypes cannot be reached by raising or reducing precision.
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(kernel_build_info);
  bool flag = false;
  for (size_t input_index = 0; input_index < kernel_build_info->GetInputNum(); ++input_index) {
    auto in_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
    auto device_dtype = kernel_build_info->GetInputDeviceType(input_index);
    if (device_dtype == kNumberTypeFloat || device_dtype == kNumberTypeFloat32) {
      device_dtype = kNumberTypeFloat32;
    }
    if (!CheckHitTargetDtype(type_map, in_dtype, device_dtype, &flag)) {
      return false;
    }
  }

  for (size_t output_index = 0; output_index < kernel_build_info->GetOutputNum(); ++output_index) {
    auto in_dtype = AnfAlgo::GetOutputInferDataType(cnode, output_index);
    auto device_dtype = kernel_build_info->GetOutputDeviceType(output_index);
    if (device_dtype == kNumberTypeFloat || device_dtype == kNumberTypeFloat32) {
      device_dtype = kNumberTypeFloat32;
    }

    if (!CheckHitTargetDtype(type_map, in_dtype, device_dtype, &flag)) {
      return false;
    }
  }
  if (flag) {
    auto node_name = AnfAlgo::GetCNodeName(cnode);
    MS_LOG(WARNING) << "Node:[" << node_name << "] reduce precision from int64 to int32";
  }
  return true;
}

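// Filter the candidate kernel infos that become usable once precision is raised (fp16 -> fp32). If none match and
// reduce-precision is enabled in the context, retry with the reduce map (int64 -> int32, float/fp32 -> fp16) and set
// *precision_reduce accordingly.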
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecisionMatchedKernelInfo(
  const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list,
  bool *precision_reduce) {
  MS_EXCEPTION_IF_NULL(precision_reduce);
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_kernel_info_list;
  const std::map<TypeId, TypeId> raise_map = {{kNumberTypeFloat16, kNumberTypeFloat32}};
  const std::map<TypeId, TypeId> reduce_map = {{kNumberTypeInt64, kNumberTypeInt32},
                                               {kNumberTypeFloat, kNumberTypeFloat16},
                                               {kNumberTypeFloat32, kNumberTypeFloat16}};
  // raise precision
  for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
    MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]);
    if (TagRaiseReduce(kernel_info_list[info_index], cnode, raise_map)) {
      filtered_kernel_info_list.push_back(kernel_info_list[info_index]);
    }
  }

  if (!filtered_kernel_info_list.empty()) {
    *precision_reduce = false;
    return filtered_kernel_info_list;
  }

  // reduce precision
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (context_ptr->get_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION)) {
    for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
      MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]);
      if (TagRaiseReduce(kernel_info_list[info_index], cnode, reduce_map)) {
        filtered_kernel_info_list.push_back(kernel_info_list[info_index]);
      }
    }
  }
  if (!filtered_kernel_info_list.empty()) {
    *precision_reduce = true;
  }
  return filtered_kernel_info_list;
}

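// Rewrite a Cast node's input/output formats to the format expected by the op that follows it, looked up from
// kNextOpFormatList via the kAttrPynativeNextOpName / kAttrPynativeNextIndex attributes.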
void SetCastAndWeightFormat(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  if (!AnfAlgo::HasNodeAttr(kAttrPynativeNextIndex, kernel_node) ||
      !AnfAlgo::HasNodeAttr(kAttrPynativeNextOpName, kernel_node)) {
    MS_LOG(EXCEPTION) << "The node [" << kernel_node->DebugString() << "] attr of " << kAttrPynativeNextIndex << " or "
                      << kAttrPynativeNextOpName << " has not been set yet!"
                      << " trace: " << trace::DumpSourceLines(kernel_node);
  }
  auto next_index = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPynativeNextIndex);
  auto next_op_name = AnfAlgo::GetNodeAttr<std::string>(kernel_node, kAttrPynativeNextOpName);
  auto iter = kNextOpFormatList.find(next_op_name);
  if (iter == kNextOpFormatList.end()) {
    MS_LOG(INFO) << "The op name " << next_op_name << " has not been set in the next op map";
    return;
  }
  if (iter->second.size() < next_index) {
    MS_LOG(EXCEPTION) << "Next input index " << next_index << " is out of range; the next op map max size is "
                      << iter->second.size() << " trace: " << trace::DumpSourceLines(kernel_node);
  }
  if (AnfAlgo::GetCNodeName(kernel_node) != prim::kPrimCast->name()) {
    MS_LOG(INFO) << "Only the Cast node's build info is supported to be changed!";
    return;
  }
  auto format = iter->second[next_index];
  auto info_builder =
    std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(kernel_node));
  MS_EXCEPTION_IF_NULL(info_builder);
  info_builder->SetInputsFormat({format});
  info_builder->SetOutputsFormat({format});
  AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), kernel_node.get());
}

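// Refresh the output format and device type recorded on a weight input (Parameter or ValueNode) so that they match
// what the selected kernel of kernel_node expects; CNode inputs and inputs that already own a device address are
// left untouched.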
void SetWeightFormat(const AnfNodePtr &real_input_node, std::vector<string> output_format, const CNodePtr &kernel_node,
                     size_t input_index, bool force_fresh = false) {
  MS_EXCEPTION_IF_NULL(real_input_node);
  if (real_input_node->isa<CNode>() || AnfAlgo::OutputAddrExist(real_input_node, 0)) {
    return;
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool disable_convert = real_input_node->isa<Parameter>() || real_input_node->isa<ValueNode>();
  if (disable_convert && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK)) {
    disable_convert =
      trans::kTransFormatMapOfHostToDevice.find(output_format[0]) == trans::kTransFormatMapOfHostToDevice.end();
  }
  // If the format is not in the host convert map, the host has no registered conversion function for it, so fall
  // back to the input's current format.
  if (output_format[0] != kOpFormat_DEFAULT && disable_convert) {
    output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
  }
  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(builder);
  // Set special device info for this input tensor.
  auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel_node);
  if (op_info != nullptr) {
    force_fresh = op_info->is_ref() || force_fresh;
  }
  auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
  MS_EXCEPTION_IF_NULL(selected_kernel_info);
  if (IsValueNode<tensor::Tensor>(real_input_node) &&
      AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) {
    builder->SetOutputsFormat(output_format);
    std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
    builder->SetOutputsDeviceType(output_type);
    AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
    return;
  }
  if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || force_fresh) {
    builder->SetOutputsFormat(output_format);
    std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
    builder->SetOutputsDeviceType(output_type);
    AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
  }
}

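// In graph mode, when the input feeding the kernel node is a Cast over a weight, push the requested format through
// the Cast and its weight input. Returns true when the caller does not need to refresh this input's weight format
// itself.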
bool RefreshCastAndParamWeightFormat(const AnfNodePtr &input_node, const string &format) {
  MS_EXCEPTION_IF_NULL(input_node);
  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
    return false;
  }
  if (!input_node->isa<CNode>()) {
    return false;
  }
  auto cast_node = input_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cast_node);
  if (AnfAlgo::GetCNodeName(cast_node) != prim::kPrimCast->name()) {
    return true;
  }
  if (AnfAlgo::IsFeatureMapOutput(cast_node)) {
    return true;
  }
  if (format == kOpFormat_FRACTAL_ZN_RNN || format == kOpFormat_ND_RNN_BIAS) {
    return true;
  }
  auto info_builder =
    std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(input_node));
  MS_EXCEPTION_IF_NULL(info_builder);
  info_builder->SetInputsFormat({format});
  info_builder->SetOutputsFormat({format});
  AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), cast_node.get());
  auto cast_input_node = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cast_node, 0), 0);
  SetWeightFormat(cast_input_node.first, {format}, cast_node, 0, true);
  return true;
}
}  // namespace
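// Walk the node's inputs and refresh the device format/dtype of its weight inputs (Parameters and ValueNodes) so
// they agree with the node's selected kernel build info.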
void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
  MS_EXCEPTION_IF_NULL(selected_kernel_info);
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  for (size_t input_index = 0; input_index < input_num; ++input_index) {
    auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
    MS_EXCEPTION_IF_NULL(input_kernel_node);
    auto input_with_index = AnfAlgo::VisitKernelWithReturnType(input_kernel_node, 0);
    MS_EXCEPTION_IF_NULL(input_with_index.first);
    auto real_input_node = input_with_index.first;
    MS_EXCEPTION_IF_NULL(real_input_node);
    if (RefreshCastAndParamWeightFormat(real_input_node, selected_kernel_info->GetInputFormat(input_index))) {
      continue;
    }
    if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
      continue;
    }
    auto refresh_format = selected_kernel_info->GetInputFormat(input_index);
    std::vector<std::string> output_format = {refresh_format};
    SetWeightFormat(real_input_node, output_format, kernel_node, input_index);
  }
}

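// Select the best-matching kernel build info for the node from kernel_info_list: first among the candidates that
// match the inferred data types exactly, otherwise among the candidates reachable by raising or reducing precision.
// The chosen build info is attached to the node and its weight inputs; the returned status records how it matched.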
KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
                                        const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  KernelSelectStatus select_status = kNoMatched;
  if (kernel_info_list.empty()) {
    return select_status;
  }
  bool precision_reduce = false;
  std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
  // Match kernel info
  // Filter the kernel infos that match the frontend-inferred data types
  auto filtered_kernel_info_list = FilteredKernelInfoByDtype(kernel_node, kernel_info_list);
  if (!filtered_kernel_info_list.empty()) {
    selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
    select_status = kStatusAllMatched;
  } else {
    // Select a kernel info by raising or reducing precision
    filtered_kernel_info_list =
      FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce);
    selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
    if (selected_kernel_info == nullptr) {
      return select_status;
    } else {
      MS_LOG(INFO) << PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
      select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
    }
  }
  // Set kernel build info to node
  MS_LOG(INFO) << "Current node: " << kernel_node->fullname_with_scope()
               << " selected: " << selected_kernel_info->ToString();
  AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
  // Set format and data type for input tensor.
  if (AnfAlgo::HasNodeAttr(kAttrPynativeNextOpName, kernel_node)) {
    SetCastAndWeightFormat(kernel_node);
  }
  SetTensorDeviceInfo(kernel_node);
  return select_status;
}

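// Top-level kernel selection for a single node: graph kernels are dispatched to SelectGraphKernelInfo, otherwise the
// TBE (ai_core) candidates are queried first and the AICPU candidates are used as a fallback. If neither yields a
// usable kernel, the supported candidates are logged and an exception is raised (except for LabelSwitch).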
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list;
  MS_EXCEPTION_IF_NULL(kernel_node);
  if (AnfAlgo::IsGraphKernel(kernel_node)) {
    auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
    MS_EXCEPTION_IF_NULL(func_graph);
    SelectGraphKernelInfo(kernel_node, func_graph);
    return kStatusAllMatched;
  }
  kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type);
  auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
  // If no valid ai_core kernel info is found, search the ai_cpu kernel infos instead
  if (select_status == kNoMatched) {
    MS_LOG(DEBUG) << "The node [" << kernel_node->fullname_with_scope()
                  << "] cannot find valid TBE kernel info, try to get ai_cpu kernel info";
    kernel::AICPUQuery(kernel_node, &aicpu_kernel_info_list);
    select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list);
    AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
  }
  // The kernel info cannot be found in either the ai_cpu or the ai_core kernel lists
  if (select_status == kNoMatched) {
    std::ostringstream buffer;
    PrintInputAndOutputInferType(buffer, kernel_node);
    MS_LOG(WARNING) << ">>> The supported kernel info(input and output data type) candidates list:";
    for (size_t index = 0; index < kernel_info_list.size(); ++index) {
      MS_LOG(WARNING) << "Ai_core kernel info [" << index << "] :" << kernel_info_list[index]->ToString();
    }
    for (size_t index = 0; index < aicpu_kernel_info_list.size(); ++index) {
      MS_LOG(WARNING) << "Ai_cpu kernel info [" << (kernel_info_list.size() + index)
                      << "] :" << aicpu_kernel_info_list[index]->ToString();
    }
    if (IsPrimitiveCNode(kernel_node, prim::kPrimLabelSwitch)) {
      auto selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, kernel_info_list);
      AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
      // Set format and data type for input tensor.
      SetTensorDeviceInfo(kernel_node);
    } else {
      MS_LOG(WARNING) << " <<<";
      MS_LOG(EXCEPTION) << "Can not find any available operator info for operator ["
                        << kernel_node->fullname_with_scope()
                        << "]. Maybe the data type is not supported: " << buffer.str()
                        << ", or maybe the operator is not supported on the current platform.\n Node trace: "
                        << trace::DumpSourceLines(kernel_node);
    }
  }
  return select_status;
}

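// Rebuild the node's selected kernel build info with an explicit kernel type and processor. When kernel_type is
// UNKNOWN_KERNEL_TYPE, the type is decided by querying the TBE candidates first, then AICPU, and finally falling
// back to AKG.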
void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  auto kernel_info = dynamic_cast<device::KernelInfo *>(kernel_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto kernel_build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(kernel_build_info);

  if (AnfAlgo::IsGraphKernel(kernel_node)) {
    return;
  }

  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(builder);
  builder->SetOriginDataFormat(kernel_build_info->GetOriginDataFormat());
  builder->SetInputsFormat(kernel_build_info->GetAllInputFormats());
  builder->SetInputsDeviceType(kernel_build_info->GetAllInputDeviceTypes());
  builder->SetOutputsFormat(kernel_build_info->GetAllOutputFormats());
  builder->SetOutputsDeviceType(kernel_build_info->GetAllOutputDeviceTypes());
  builder->SetOpPattern(kernel_build_info->op_pattern());
  builder->SetFusionType(kernel_build_info->fusion_type());

  auto new_kernel_type = kernel_type;
  auto new_processor = kernel_build_info->processor();
  if (kernel_type == UNKNOWN_KERNEL_TYPE) {
    std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
    std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list;
    kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type);
    auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
    if (select_status != kNoMatched) {
      new_kernel_type = TBE_KERNEL;
      new_processor = kernel::Processor::AICORE;
      MS_LOG(INFO) << kernel_node->fullname_with_scope() << " uses TBE_KERNEL";
    } else {
      kernel::AICPUQuery(kernel_node, &aicpu_kernel_info_list);
      select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list);
      if (select_status != kNoMatched) {
        new_kernel_type = AICPU_KERNEL;
        new_processor = kernel::Processor::AICPU;
        MS_LOG(INFO) << kernel_node->fullname_with_scope() << " uses AICPU_KERNEL";
      }
    }
  }
  if (new_kernel_type == UNKNOWN_KERNEL_TYPE) {
    new_kernel_type = AKG_KERNEL;
    new_processor = kernel::Processor::AICORE;
    MS_LOG(INFO) << kernel_node->fullname_with_scope() << " uses AKG_KERNEL";
  }
  builder->SetKernelType(new_kernel_type);
  builder->SetProcessor(new_processor);
  kernel_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
}
}  // namespace ascend
}  // namespace device
}  // namespace mindspore