/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/ascend/kernel_select_ascend.h"

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/kernel_compiler/kernel_query.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "common/trans.h"
#include "debug/anf_ir_dump.h"
#include "frontend/operator/ops.h"
#include "utils/ms_context.h"
#include "utils/ms_utils.h"
#include "utils/trace_base.h"
namespace mindspore {
namespace device {
namespace ascend {
namespace {
const int kWeightUnInitScore = 1;
const int kWeightInitScore = 2;
const int kFeatureMapBaseScore = 10;
constexpr auto kPriChoosenFormat = "pri_format";
enum MatchCountPriority : int {
  MATCH_COUNT_PRIORITY_BEGIN = 0,
  MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
  MATCH_FORMAT_COUNT,
  MATCH_SPECIAL_FORMAT_COUNT,
  MATCH_DEFAULT_FORMAT_COUNT,
  MATCH_OUTPUT_DTYPE_COUNT,
  MATCH_COUNT_PRIORITY_END
};
const std::map<std::string, std::vector<std::string>> kNextOpFormatList = {
  {prim::kPrimConv2D->name(), {kOpFormat_NC1HWC0, kOpFormat_FRAC_Z}}};

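// Returns true only when every input and output device type in the candidate kernel build info
// equals the data type inferred for the node, i.e. the candidate needs no precision change.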
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
  MS_EXCEPTION_IF_NULL(cnode);
  // Check input data type
  for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
    TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
    if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {
      return false;
    }
  }
  // Check output data type
  for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
    if (kernel_build_info.GetOutputDeviceType(output_index) != AnfAlgo::GetOutputInferDataType(cnode, output_index)) {
      return false;
    }
  }
  return true;
}

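// Chooses the preferred input format for the node: starts from NC1HWC0, adopts the first special
// format found on a feature-map input, and falls back to the default format when the inputs carry
// mixed special formats or shapes that are neither 4-D nor scalar-like. 5-D inputs prefer
// NDC1HWC0 unless FRAC_NZ was already chosen. The result is cached in the "pri_format" attribute.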
string GetPriorityMatchFormat(const CNodePtr &cnode) {
  constexpr size_t k5dSize = 5;
  constexpr size_t k4dSize = 4;
  string priority_matched_format = kOpFormat_NC1HWC0;
  bool is_init = false;
  bool need_change_nd = false;
  bool is_5d_input = false;
  size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
  for (size_t index = 0; index < input_num; ++index) {
    auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
    if (AnfAlgo::IsFeatureMapInput(cnode, index) &&
        kHWSpecialFormatSet.find(pre_output_format) != kHWSpecialFormatSet.end()) {
      priority_matched_format = !is_init ? pre_output_format : priority_matched_format;
      is_init = true;
    }
    // The feature-map inputs carry two or more different special formats.
    if (priority_matched_format != pre_output_format && pre_output_format != kOpFormat_DEFAULT) {
      priority_matched_format = kOpFormat_DEFAULT;
    }
    auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size();
    if (input_shape_size == k5dSize) {
      is_5d_input = true;
    }
    need_change_nd = (need_change_nd || (input_shape_size != k4dSize && input_shape_size > 1));
  }
  if (need_change_nd && priority_matched_format != kOpFormat_FRAC_NZ) {
    priority_matched_format = kOpFormat_DEFAULT;
  }
  if (is_5d_input && priority_matched_format != kOpFormat_FRAC_NZ) {
    priority_matched_format = kOpFormat_NDC1HWC0;
  }
  AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode);
  return priority_matched_format;
}

/**
 * Compare two count vectors by priority and keep the better one. The comparison works like
 * comparing two numbers digit by digit: the highest-priority position is compared first, and
 * only on a tie does the comparison move to the next position.
 * Example: [3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3]
 */
bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best_item) {
  MS_EXCEPTION_IF_NULL(best_item);
  if (cur_item.size() != best_item->size()) {
    MS_LOG(ERROR) << "Item size should be the same!";
    return false;
  }
  // Update best_item by comparing cur_item with best_item
  for (size_t i = 0; i < cur_item.size(); i++) {
    if (cur_item[i] > best_item->at(i)) {
      *best_item = cur_item;
      return true;
    } else if (cur_item[i] == best_item->at(i)) {
      continue;
    } else {
      return false;
    }
  }
  return false;
}

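// Scores one candidate kernel build info against the node. Each input contributes a weighted score
// (feature-map inputs weigh more than initialized weights, uninitialized weights least) to the
// dtype, format, special-format and default-format counters; each output adds to the output-dtype
// and special-format counters. The resulting vector is later compared with PriorityChooseItem.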
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
                          std::vector<int> *const cur_kernelinfo_match_counts) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts);
  if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) {
    MS_LOG(EXCEPTION) << "Size of cur_kernelinfo_match_counts is less than " << MATCH_COUNT_PRIORITY_END;
  }
  auto pri_match_format = GetPriorityMatchFormat(kernel_node);
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  for (size_t input_index = 0; input_index < input_num; ++input_index) {
    auto input_anf_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, input_index), 0).first;
    MS_EXCEPTION_IF_NULL(input_anf_node);
    // We do not take ValueNode into consideration in graph kernel.
    auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWeightInitScore;
    if (AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) {
      base_score = kWeightUnInitScore;
    }
    if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
      (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score;
    }
    // We match the fixed output precision first.
    auto prev_device_type = AnfAlgo::GetPrevNodeOutputPrecision(kernel_node, input_index);
    if (prev_device_type == kTypeUnknown) {
      prev_device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
    }
    if (kernel_build_info.GetInputDeviceType(input_index) == prev_device_type) {
      (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score;
    }
    if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
      (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score;
    }
    if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT ||
        kernel_build_info.GetInputFormat(input_index) == kOpFormat_NCDHW) {
      (*cur_kernelinfo_match_counts)[MATCH_DEFAULT_FORMAT_COUNT] += base_score;
    }
  }

  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  for (size_t output_index = 0; output_index < output_num; ++output_index) {
    // Count how many output dtypes of the kernel info match the inferred (abstract) dtypes.
    if (kernel_build_info.GetOutputDeviceType(output_index) ==
        AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) {
      (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT] += 1;
    }
    if (kernel_build_info.GetOutputFormat(output_index) == pri_match_format) {
      (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += 1;
    }
  }
}

std::string PrintRaiseOrReducePrecisionSelectedInfo(
  const CNodePtr &cnode, const std::shared_ptr<kernel::KernelBuildInfo> &selected_kernel_build_info,
  bool precision_reduce) {
  MS_EXCEPTION_IF_NULL(selected_kernel_build_info);
  MS_EXCEPTION_IF_NULL(cnode);
  std::ostringstream buffer;
  buffer << cnode->DebugString();
  if (precision_reduce) {
    buffer << " Reduce precision, node datatype: \n";
  } else {
    buffer << " Raise precision, node datatype: \n";
  }
  PrintInputAndOutputInferType(buffer, cnode);
  buffer << ", select kernel:" << selected_kernel_build_info->ToString();
  return buffer.str();
}

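// Picks the best candidate from kernel_info_list by scoring each one with UpdateCurMatchCounts
// and keeping the winner of the priority comparison. Returns nullptr when the list is empty.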
std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo(
  const CNodePtr &kernel_node, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
  if (kernel_info_list.empty()) {
    return nullptr;
  }
  std::vector<int> most_match_counts = {-1, -1, -1, -1, -1};
  size_t selected_index = 0;
  for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
    std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0};
    auto kernel_info_ptr = kernel_info_list[info_index];
    MS_EXCEPTION_IF_NULL(kernel_info_ptr);
    UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts);
    // Keep the candidate whose match-count vector wins the priority comparison
    // (see MatchCountPriority for the ordering of the counters).
    if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) {
      selected_index = info_index;
    }
  }
  return kernel_info_list[selected_index];
}

std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype(
  const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> result;
  for (const auto &kernel_build_info : kernel_info_list) {
    MS_EXCEPTION_IF_NULL(kernel_build_info);
    if (!MatchInferOutputDataType(cnode, *kernel_build_info)) {
      continue;
    }
    result.push_back(kernel_build_info);
  }
  return result;
}

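// Checks whether the candidate device dtype is reachable from the inferred dtype, either directly
// (equal dtypes) or through the raise/reduce mapping in type_map. Sets *flag when an int64
// inferred dtype would be reduced to int32.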
bool CheckHitTargetDtype(const std::map<TypeId, TypeId> &type_map, const TypeId &in_dtype, const TypeId &device_dtype,
                         bool *flag) {
  auto iter = type_map.find(in_dtype);
  // If the inferred dtype is not in type_map and does not equal the kernel info dtype, return false.
  if (iter == type_map.end() && in_dtype != device_dtype) {
    return false;
  }
  // If the inferred dtype is in type_map but neither the mapped raise/reduce dtype nor the inferred
  // dtype itself equals the kernel info dtype, return false.
  if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) {
    return false;
  }
  if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) {
    *flag = true;
  }
  return true;
}

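// Returns true when every input and output dtype of the candidate kernel build info can be reached
// from the inferred dtypes through the given raise/reduce type_map; warns when this implies
// reducing int64 to int32. kNumberTypeFloat is treated as kNumberTypeFloat32 before checking.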
bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode,
                    const std::map<TypeId, TypeId> &type_map) {
  // Filter out kernel info that does not support raising or reducing the datatype.
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(kernel_build_info);
  bool flag = false;
  for (size_t input_index = 0; input_index < kernel_build_info->GetInputNum(); ++input_index) {
    auto in_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
    auto device_dtype = kernel_build_info->GetInputDeviceType(input_index);
    if (device_dtype == kNumberTypeFloat || device_dtype == kNumberTypeFloat32) {
      device_dtype = kNumberTypeFloat32;
    }
    if (!CheckHitTargetDtype(type_map, in_dtype, device_dtype, &flag)) {
      return false;
    }
  }

  for (size_t output_index = 0; output_index < kernel_build_info->GetOutputNum(); ++output_index) {
    auto in_dtype = AnfAlgo::GetOutputInferDataType(cnode, output_index);
    auto device_dtype = kernel_build_info->GetOutputDeviceType(output_index);
    if (device_dtype == kNumberTypeFloat || device_dtype == kNumberTypeFloat32) {
      device_dtype = kNumberTypeFloat32;
    }

    if (!CheckHitTargetDtype(type_map, in_dtype, device_dtype, &flag)) {
      return false;
    }
  }
  if (flag) {
    auto node_name = AnfAlgo::GetCNodeName(cnode);
    MS_LOG(WARNING) << "Node:[" << node_name << "] reduce precision from int64 to int32";
  }
  return true;
}

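// Falls back to precision conversion when no candidate matches the inferred dtypes exactly:
// first keeps candidates reachable by raising precision (fp16 -> fp32); if none exist and
// MS_CTX_ENABLE_REDUCE_PRECISION is on, keeps candidates reachable by reducing precision
// (int64 -> int32, fp32 -> fp16). *precision_reduce reports which branch produced the result.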
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecisionMatchedKernelInfo(
  const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list,
  bool *precision_reduce) {
  MS_EXCEPTION_IF_NULL(precision_reduce);
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_kernel_info_list;
  const std::map<TypeId, TypeId> raise_map = {{kNumberTypeFloat16, kNumberTypeFloat32}};
  const std::map<TypeId, TypeId> reduce_map = {{kNumberTypeInt64, kNumberTypeInt32},
                                               {kNumberTypeFloat, kNumberTypeFloat16},
                                               {kNumberTypeFloat32, kNumberTypeFloat16}};
  // Raise precision
  for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
    MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]);
    if (TagRaiseReduce(kernel_info_list[info_index], cnode, raise_map)) {
      filtered_kernel_info_list.push_back(kernel_info_list[info_index]);
    }
  }

  if (!filtered_kernel_info_list.empty()) {
    *precision_reduce = false;
    return filtered_kernel_info_list;
  }

  // Reduce precision
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (context_ptr->get_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION)) {
    for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
      MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]);
      if (TagRaiseReduce(kernel_info_list[info_index], cnode, reduce_map)) {
        filtered_kernel_info_list.push_back(kernel_info_list[info_index]);
      }
    }
  }
  if (!filtered_kernel_info_list.empty()) {
    *precision_reduce = true;
  }
  return filtered_kernel_info_list;
}

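// For a Cast node that carries the kAttrPynativeNextOpName and kAttrPynativeNextIndex attributes,
// rewrites its input and output formats to the format expected by the recorded next op, according
// to kNextOpFormatList.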
void SetCastAndWeightFormat(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  if (!AnfAlgo::HasNodeAttr(kAttrPynativeNextIndex, kernel_node) ||
      !AnfAlgo::HasNodeAttr(kAttrPynativeNextOpName, kernel_node)) {
    MS_LOG(EXCEPTION) << "The node [" << kernel_node->DebugString() << "] attr of " << kAttrPynativeNextIndex << " or "
                      << kAttrPynativeNextOpName << " has not been set yet!"
                      << " trace: " << trace::DumpSourceLines(kernel_node);
  }
  auto next_index = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPynativeNextIndex);
  auto next_op_name = AnfAlgo::GetNodeAttr<std::string>(kernel_node, kAttrPynativeNextOpName);
  auto iter = kNextOpFormatList.find(next_op_name);
  if (iter == kNextOpFormatList.end()) {
    MS_LOG(INFO) << "The op name " << next_op_name << " has not been set in the next op map.";
    return;
  }
  if (iter->second.size() < next_index) {
    MS_LOG(EXCEPTION) << "Next input index " << next_index << " is out of range; the next op map max size is "
                      << iter->second.size() << " trace: " << trace::DumpSourceLines(kernel_node);
  }
  if (AnfAlgo::GetCNodeName(kernel_node) != prim::kPrimCast->name()) {
    MS_LOG(INFO) << "Only the Cast node's build info is supported to be changed!";
    return;
  }
  auto format = iter->second[next_index];
  auto info_builder =
    std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(kernel_node));
  MS_EXCEPTION_IF_NULL(info_builder);
  info_builder->SetInputsFormat({format});
  info_builder->SetOutputsFormat({format});
  AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), kernel_node.get());
}

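// Assigns a device format and dtype to a weight input (Parameter or ValueNode) that has no kernel
// build info yet. The requested format is kept only when the host-to-device transformation for it
// is registered; otherwise the input keeps its current output format. force_fresh (or a ref op)
// refreshes the info even when a device dtype was already set.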
void SetWeightFormat(const AnfNodePtr &real_input_node, std::vector<string> output_format, const CNodePtr &kernel_node,
                     size_t input_index, bool force_fresh = false) {
  MS_EXCEPTION_IF_NULL(real_input_node);
  if (real_input_node->isa<CNode>() || AnfAlgo::OutputAddrExist(real_input_node, 0)) {
    return;
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool disable_convert = real_input_node->isa<Parameter>() || real_input_node->isa<ValueNode>();
  if (disable_convert && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK)) {
    disable_convert =
      trans::kTransFormatMapOfHostToDevice.find(output_format[0]) == trans::kTransFormatMapOfHostToDevice.end();
  }
  // If the format is not found in the host convert format map, the host has not registered a
  // convert function for this format.
  if (output_format[0] != kOpFormat_DEFAULT && disable_convert) {
    output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
  }
  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(builder);
  // Set the special device info of an input tensor.
  auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel_node);
  if (op_info != nullptr) {
    force_fresh = op_info->is_ref() || force_fresh;
  }
  auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
  MS_EXCEPTION_IF_NULL(selected_kernel_info);
  if (IsValueNode<tensor::Tensor>(real_input_node) &&
      AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) {
    builder->SetOutputsFormat(output_format);
    std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
    builder->SetOutputsDeviceType(output_type);
    AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
    return;
  }
  if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || force_fresh) {
    builder->SetOutputsFormat(output_format);
    std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
    builder->SetOutputsDeviceType(output_type);
    AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
  }
}

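// In graph mode, when an input is a Cast node that produces a weight (not a feature map), refreshes
// that Cast node's formats to the given format and propagates the format to the Cast's own weight
// input. Returns true when the caller does not need to process this input any further.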
bool RefreshCastAndParamWeightFormat(const AnfNodePtr &input_node, const string &format) {
  MS_EXCEPTION_IF_NULL(input_node);
  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
    return false;
  }
  if (!input_node->isa<CNode>()) {
    return false;
  }
  auto cast_node = input_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cast_node);
  if (AnfAlgo::GetCNodeName(cast_node) != prim::kPrimCast->name()) {
    return true;
  }
  if (AnfAlgo::IsFeatureMapOutput(cast_node)) {
    return true;
  }
  if (format == kOpFormat_FRACTAL_ZN_RNN || format == kOpFormat_ND_RNN_BIAS) {
    return true;
  }
  auto info_builder =
    std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(input_node));
  MS_EXCEPTION_IF_NULL(info_builder);
  info_builder->SetInputsFormat({format});
  info_builder->SetOutputsFormat({format});
  AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), cast_node.get());
  auto cast_input_node = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cast_node, 0), 0);
  SetWeightFormat(cast_input_node.first, {format}, cast_node, 0, true);
  return true;
}
}  // namespace
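// After a kernel build info has been selected for the node, propagates the chosen input formats to
// the node's weight inputs (Parameters marked as weights and ValueNodes) and to Cast nodes that
// feed them.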
void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
  MS_EXCEPTION_IF_NULL(selected_kernel_info);
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  for (size_t input_index = 0; input_index < input_num; ++input_index) {
    auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
    MS_EXCEPTION_IF_NULL(input_kernel_node);
    auto input_with_index = AnfAlgo::VisitKernelWithReturnType(input_kernel_node, 0);
    MS_EXCEPTION_IF_NULL(input_with_index.first);
    auto real_input_node = input_with_index.first;
    MS_EXCEPTION_IF_NULL(real_input_node);
    if (RefreshCastAndParamWeightFormat(real_input_node, selected_kernel_info->GetInputFormat(input_index))) {
      continue;
    }
    if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
      continue;
    }
    auto refresh_format = selected_kernel_info->GetInputFormat(input_index);
    std::vector<std::string> output_format = {refresh_format};
    SetWeightFormat(real_input_node, output_format, kernel_node, input_index);
  }
}

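// Selects a kernel build info for the node from the given candidates. Candidates that match the
// inferred dtypes exactly are preferred; otherwise a raised- or reduced-precision candidate is
// chosen. The selected info is written back to the node and the input tensors are updated.
// Returns kStatusAllMatched, kStatusRaisePrecision, kStatusReducePrecision or kNoMatched.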
KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
                                        const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  KernelSelectStatus select_status = kNoMatched;
  if (kernel_info_list.empty()) {
    return select_status;
  }
  bool precision_reduce = false;
  std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
  // Matched kernel info
  // Filter the kernel info that matches the inferred data type
  auto filtered_kernel_info_list = FilteredKernelInfoByDtype(kernel_node, kernel_info_list);
  if (!filtered_kernel_info_list.empty()) {
    selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
    select_status = kStatusAllMatched;
  } else {
    // Select kernel info using raised or reduced precision
    filtered_kernel_info_list =
      FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce);
    selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
    if (selected_kernel_info == nullptr) {
      return select_status;
    } else {
      MS_LOG(INFO) << PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
      select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
    }
  }
  // Set kernel build info to node
  MS_LOG(INFO) << "Current node: " << kernel_node->fullname_with_scope()
               << " selected: " << selected_kernel_info->ToString();
  AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
  // Set format and data type for input tensor.
  if (AnfAlgo::HasNodeAttr(kAttrPynativeNextOpName, kernel_node)) {
    SetCastAndWeightFormat(kernel_node);
  }
  SetTensorDeviceInfo(kernel_node);
  return select_status;
}

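// Entry point of kernel selection for a single node. Graph-kernel nodes are delegated to
// SelectGraphKernelInfo. Other nodes first query the ai_core (TBE) kernel infos; if none match,
// the ai_cpu kernel infos are queried and the node is tagged as an AICPU kernel. When nothing
// matches, the supported candidates are printed and an exception is thrown (except for LabelSwitch
// nodes, which fall back to the best available candidate).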
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list;
  MS_EXCEPTION_IF_NULL(kernel_node);
  if (AnfAlgo::IsGraphKernel(kernel_node)) {
    auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
    MS_EXCEPTION_IF_NULL(func_graph);
    SelectGraphKernelInfo(kernel_node, func_graph);
    return kStatusAllMatched;
  }
  kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type);
  auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
  // If no valid ai_core kernel info is found, search again in the ai_cpu kernel info
  if (select_status == kNoMatched) {
    MS_LOG(DEBUG) << "The node [" << kernel_node->fullname_with_scope()
                  << "] cannot find valid TBE kernel info, try to get ai_cpu kernel info";
    kernel::AICPUQuery(kernel_node, &aicpu_kernel_info_list);
    select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list);
    AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
  }
  // The kernel info cannot be found in either the ai_core or the ai_cpu kernel lists
  if (select_status == kNoMatched) {
    std::ostringstream buffer;
    PrintInputAndOutputInferType(buffer, kernel_node);
    MS_LOG(WARNING) << ">>> The supported kernel info(input and output data type) candidates list:";
    for (size_t index = 0; index < kernel_info_list.size(); ++index) {
      MS_LOG(WARNING) << "Ai_core kernel info [" << index << "] :" << kernel_info_list[index]->ToString();
    }
    for (size_t index = 0; index < aicpu_kernel_info_list.size(); ++index) {
      MS_LOG(WARNING) << "Ai_cpu kernel info [" << (kernel_info_list.size() + index)
                      << "] :" << aicpu_kernel_info_list[index]->ToString();
    }
    if (IsPrimitiveCNode(kernel_node, prim::kPrimLabelSwitch)) {
      auto selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, kernel_info_list);
      AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
      // Set format and data type for input tensor.
      SetTensorDeviceInfo(kernel_node);
    } else {
      MS_LOG(WARNING) << " <<<";
      MS_LOG(EXCEPTION) << "Can not find any available operator info for operator ["
                        << kernel_node->fullname_with_scope()
                        << "]. Maybe the data type is not supported: " << buffer.str()
                        << ", or maybe the operator is not supported on the current platform.\n Node trace: "
                        << trace::DumpSourceLines(kernel_node);
    }
  }
  return select_status;
}

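// Rebuilds the kernel build info of a node whose kernel type may still be unknown: copies the
// formats, dtypes, op pattern and fusion type of the existing info, then decides the kernel type
// and processor by querying TBE first, then AICPU, and finally falling back to AKG on AICORE.
// Graph-kernel nodes are left untouched.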
void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  auto kernel_info = dynamic_cast<device::KernelInfo *>(kernel_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto kernel_build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(kernel_build_info);

  if (AnfAlgo::IsGraphKernel(kernel_node)) {
    return;
  }

  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(builder);
  builder->SetOriginDataFormat(kernel_build_info->GetOriginDataFormat());
  builder->SetInputsFormat(kernel_build_info->GetAllInputFormats());
  builder->SetInputsDeviceType(kernel_build_info->GetAllInputDeviceTypes());
  builder->SetOutputsFormat(kernel_build_info->GetAllOutputFormats());
  builder->SetOutputsDeviceType(kernel_build_info->GetAllOutputDeviceTypes());
  builder->SetOpPattern(kernel_build_info->op_pattern());
  builder->SetFusionType(kernel_build_info->fusion_type());

  auto new_kernel_type = kernel_type;
  auto new_processor = kernel_build_info->processor();
  if (kernel_type == UNKNOWN_KERNEL_TYPE) {
    std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
    std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list;
    kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type);
    auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
    if (select_status != kNoMatched) {
      new_kernel_type = TBE_KERNEL;
      new_processor = kernel::Processor::AICORE;
      MS_LOG(INFO) << kernel_node->fullname_with_scope() << " uses TBE_KERNEL";
    } else {
      kernel::AICPUQuery(kernel_node, &aicpu_kernel_info_list);
      select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list);
      if (select_status != kNoMatched) {
        new_kernel_type = AICPU_KERNEL;
        new_processor = kernel::Processor::AICPU;
        MS_LOG(INFO) << kernel_node->fullname_with_scope() << " uses AICPU_KERNEL";
      }
    }
  }
  if (new_kernel_type == UNKNOWN_KERNEL_TYPE) {
    new_kernel_type = AKG_KERNEL;
    new_processor = kernel::Processor::AICORE;
    MS_LOG(INFO) << kernel_node->fullname_with_scope() << " uses AKG_KERNEL";
  }
  builder->SetKernelType(new_kernel_type);
  builder->SetProcessor(new_processor);
  kernel_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
}
}  // namespace ascend
}  // namespace device
}  // namespace mindspore