/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/hal/device/kernel_info_setter.h"
#include <algorithm>
#include <memory>
#include <tuple>
#include <string>
#include <set>
#include "kernel/framework_utils.h"
#include "ops/random_op_name.h"
#include "ops/nn_optimizer_op_name.h"
#include "ops/sparse_ops.h"
#include "ops/conv_pool_ops.h"
#include "ops/nn_ops.h"
#include "ops/array_ops.h"
#include "ops/framework_ops.h"
#include "kernel/common_utils.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "kernel/kernel.h"
#include "kernel/kernel_build_info.h"
#include "kernel/oplib/opinfo.h"
#include "kernel/oplib/oplib.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "plugin/device/gpu/kernel/custom/custom_aot_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
#include "utils/ms_context.h"
#include "utils/ms_utils.h"
#include "include/common/utils/utils.h"
#include "ops/op_name.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"

namespace mindspore {
namespace device {
namespace gpu {
using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
using mindspore::kernel::KernelBuildInfo;
namespace {
constexpr auto kPatternOpaque = "Opaque";
static const std::set<std::string> kVmapGPUWhiteList = {kUnsortedSegmentSumOpName,
                                                        kUnsortedSegmentProdOpName,
                                                        kUniqueWithPadOpName,
                                                        kMaskedFillOpName,
                                                        kDataFormatDimMapOpName,
                                                        kInstanceNormOpName,
                                                        kInstanceNormGradOpName,
                                                        kRandomChoiceWithMaskOpName,
                                                        kAdamOpName,
                                                        kSplitOpName,
                                                        kApplyAdagradDAOpName,
                                                        kApplyRMSPropOpName,
                                                        kApplyCenteredRMSPropOpName,
                                                        kRandomShuffleOpName,
                                                        kApplyAdamWithAmsgradOpName,
                                                        kApplyAdamWithAmsgradV2OpName,
                                                        kApplyProximalAdagradOpName,
                                                        kMatrixBandPartOpName,
                                                        kDiagOpName,
                                                        kSparseSegmentMeanOpName};

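// Maps a kernel type to its op imply type: native GPU kernels imply kImplyGPU; everything else is treated as AKG.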
kernel::OpImplyType GetImplyType(KernelType kernel_type) {
  kernel::OpImplyType imply_type =
    kernel_type == KernelType::GPU_KERNEL ? kernel::OpImplyType::kImplyGPU : kernel::OpImplyType::kImplyAKG;
  return imply_type;
}

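// Checks whether an alternative kernel build info matches the selected one in input/output count, format and device
// type. When match_none is set, an empty format or a kMetaTypeNone device type in the alternative info matches
// anything.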
bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info,
                     const std::shared_ptr<KernelBuildInfo> &selected_kernel_info, bool match_none = false) {
  MS_EXCEPTION_IF_NULL(selected_kernel_info);
  MS_EXCEPTION_IF_NULL(alternative_kernel_info);
  size_t selected_input_num = selected_kernel_info->GetInputNum();
  size_t alternative_input_num = alternative_kernel_info->GetInputNum();
  if (selected_input_num != alternative_input_num) {
    return false;
  }
  for (size_t i = 0; i < selected_input_num; i++) {
    auto format = alternative_kernel_info->GetInputFormat(i);
    if (selected_kernel_info->GetInputFormat(i) != format && (!match_none || !format.empty())) {
      return false;
    }
    auto type = alternative_kernel_info->GetInputDeviceType(i);
    if (selected_kernel_info->GetInputDeviceType(i) != type && (!match_none || type != TypeId::kMetaTypeNone)) {
      return false;
    }
  }

  size_t selected_output_num = selected_kernel_info->GetOutputNum();
  size_t alternative_output_num = alternative_kernel_info->GetOutputNum();
  if (selected_output_num != alternative_output_num) {
    return false;
  }
  for (size_t i = 0; i < selected_output_num; i++) {
    auto format = alternative_kernel_info->GetOutputFormat(i);
    if (selected_kernel_info->GetOutputFormat(i) != format && (!match_none || !format.empty())) {
      return false;
    }
    auto type = alternative_kernel_info->GetOutputDeviceType(i);
    if (selected_kernel_info->GetOutputDeviceType(i) != type && (!match_none || type != TypeId::kMetaTypeNone)) {
      return false;
    }
  }
  return true;
}

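// Collects the data types supported by a kernel as a printable string, first from the native GPU kernel registry
// and then from the op info registered in OpLib.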
std::string GetSupportedTypesStr(const CNodePtr &kernel_node, KernelType kernel_type) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  std::string supported_type_lists;
  // Custom op gets reg info from OpLib instead of NativeGpuKernelMod.
  if (!IsPrimitiveCNode(kernel_node, prim::kPrimCustom)) {
    auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
    if (kernel::Factory<kernel::NativeGpuKernelMod>::Instance().IsRegistered(kernel_name)) {
      auto kernel_attr_list = kernel::NativeGpuKernelMod::GetGpuSupportedList(kernel_name);
      if (!kernel_attr_list.empty()) {
        for (size_t attr_index = 0; attr_index < kernel_attr_list.size(); ++attr_index) {
          std::string type_list = "input[";
          auto attr = kernel_attr_list[attr_index];
          for (size_t input_index = 0; input_index < attr.GetInputSize(); ++input_index) {
            type_list = type_list + TypeIdToString(attr.GetInputAttr(input_index).dtype) +
                        ((input_index == (attr.GetInputSize() - 1)) ? "" : " ");
          }
          type_list = type_list + "], output[";
          for (size_t input_index = 0; input_index < attr.GetOutputSize(); ++input_index) {
            type_list = type_list + TypeIdToString(attr.GetOutputAttr(input_index).dtype) +
                        ((input_index == (attr.GetOutputSize() - 1)) ? "" : " ");
          }
          supported_type_lists = supported_type_lists + type_list + "]; ";
        }

        return supported_type_lists;
      }
    } else {
      supported_type_lists =
        kernel::NativeGpuKernelModFactory::GetInstance().SupportedTypeList(common::AnfAlgo::GetCNodeName(kernel_node));
      if (!supported_type_lists.empty()) {
        return supported_type_lists;
      }
    }
  }

  std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
  std::string op_name = common::AnfAlgo::GetCNodeName(kernel_node);
  kernel::OpImplyType imply_type = GetImplyType(kernel_type);
  auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, imply_type);
  if (op_info_ptr == nullptr) {
    return supported_type_lists;
  }
  (void)ParseMetadata(kernel_node, op_info_ptr, kernel::Processor::CUDA, &kernel_info_list);
  for (size_t i = 0; i < kernel_info_list.size(); i++) {
    auto supported_akg_type = kernel_info_list[i]->GetAllInputDeviceTypes();
    auto supported_akg_type_out = kernel_info_list[i]->GetAllOutputDeviceTypes();
    std::string supported_akg_type_list = "input[";
    for (auto type : supported_akg_type) {
      supported_akg_type_list = supported_akg_type_list + TypeIdToString(type) + " ";
    }
    supported_type_lists = supported_type_lists + supported_akg_type_list + "], output[";
    supported_akg_type_list.clear();
    for (auto type : supported_akg_type_out) {
      supported_akg_type_list = supported_akg_type_list + TypeIdToString(type) + " ";
    }
    supported_type_lists = supported_type_lists + supported_akg_type_list + "]; ";
  }

  return supported_type_lists;
}

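// Checks whether the selected kernel build info is supported by AKG, according to the op info registered in OpLib.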
bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(selected_kernel_info);
  std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
  if (common::AnfAlgo::IsNodeInGraphKernel(kernel_node)) {
    // The op_info in OpLib is only used for basic ops,
    // so we don't care about it in GraphKernel.
    return true;
  }

  std::string op_name = common::AnfAlgo::GetCNodeName(kernel_node);

  auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kImplyAKG);
  if (op_info_ptr == nullptr) {
    MS_LOG(DEBUG) << "Cannot find op[" << op_name << "] in akg";
    return false;
  }
  if (!ParseMetadata(kernel_node, op_info_ptr, kernel::Processor::CUDA, &kernel_info_list)) {
    MS_LOG(EXCEPTION) << "Parsing metadata of op[" << op_name << "] failed.";
  }
  if (kernel_info_list.empty()) {
    MS_LOG(EXCEPTION) << "Akg does not have metadata of op[" << op_name << "].";
  }

  bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
                           [&](const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info) {
                             return CheckKernelInfo(alternative_kernel_info, selected_kernel_info);
                           });
  if (!match) {
    MS_LOG(ERROR) << "Cannot find op[" << op_name << "] in akg that matches both the data type and the format";
    return false;
  }
  return true;
}

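// Determines the kernel type (GPU_KERNEL or AKG_KERNEL) of a Custom op from its func_type attr and, if reg info is
// available, verifies the selected build info against it.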
bool SelectCustomKernel(const CNodePtr &kernel_node, const std::shared_ptr<KernelBuildInfo> &selected_kernel_info,
                        KernelType *kernel_type) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(selected_kernel_info);
  MS_EXCEPTION_IF_NULL(kernel_type);
  std::string op_name = common::AnfAlgo::GetCNodeName(kernel_node);
  // Custom op's kernel type can be one of [GPU_KERNEL, AKG_KERNEL] on GPU.
  auto func_type = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, kAttrFuncType);
  if (func_type == kCustomTypeAOT) {
    *kernel_type = KernelType::GPU_KERNEL;
    if (!kernel::Factory<kernel::NativeGpuKernelMod>::Instance().IsRegistered(op_name)) {
      kernel::Factory<kernel::NativeGpuKernelMod>::Instance().Register(
        op_name, []() { return std::make_shared<kernel::CustomAOTGpuKernelMod>(); });
    }
  } else if (IsOneOfCustomAkgType(func_type)) {
    *kernel_type = KernelType::AKG_KERNEL;
  } else {
    MS_LOG(EXCEPTION) << "Unsupported func type [" << func_type << "] for Custom op [" << op_name << "] on GPU";
  }
  kernel::OpImplyType imply_type = GetImplyType(*kernel_type);
  auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, imply_type);
  // If the Custom op has not set reg info, or the reg info contains no input information (the case of an
  // undetermined input size), then infer the info from the inputs.
  if (op_info_ptr == nullptr || op_info_ptr->inputs_ptr().size() == 0) {
    MS_LOG(INFO) << "Cannot find operator information for op[" << op_name
                 << "]. Infer operator information from inputs.";
    return true;
  }
  std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
  if (!ParseMetadata(kernel_node, op_info_ptr, kernel::Processor::CUDA, &kernel_info_list)) {
    MS_LOG(EXCEPTION) << "Parsing metadata of op[" << op_name << "] failed.";
  }
  if (kernel_info_list.empty()) {
    MS_LOG(EXCEPTION) << "Cannot find valid metadata of op[" << op_name << "].";
  }
  bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
                           [&](const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info) {
                             return CheckKernelInfo(alternative_kernel_info, selected_kernel_info, true);
                           });
  if (!match) {
    MS_LOG(ERROR) << "Cannot find op[" << op_name << "] that matches both the data type and the format.";
    return false;
  }
  return true;
}

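// Propagates the selected kernel's input formats and device types to the Parameter inputs of the kernel node,
// taking any precision-reduction results into account.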
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node,
                         const std::vector<std::tuple<size_t, TypeId, TypeId>> &input_reduce_detail) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  for (size_t input_index = 0; input_index < input_num; ++input_index) {
    auto input_kernel_node = kernel_node->input(input_index + 1);
    MS_EXCEPTION_IF_NULL(input_kernel_node);
    auto input_with_index = common::AnfAlgo::VisitKernel(input_kernel_node, 0);
    MS_EXCEPTION_IF_NULL(input_with_index.first);
    auto real_input_node = input_with_index.first;
    if (!real_input_node->isa<Parameter>()) {
      continue;
    }
    std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
      std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();

    auto param = real_input_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    if (!common::AnfAlgo::IsParameterWeight(param)) {
      std::vector<std::string> output_format = {kOpFormat_DEFAULT};
      builder->SetOutputsFormat(output_format);
      std::vector<TypeId> output_type = {common::AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
      builder->SetOutputsDeviceType(output_type);
      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
      continue;
    }
    if ((AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) ||
        (common::AnfAlgo::GetCNodeName(kernel_node) == "ApplyMomentum")) {
      std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
      builder->SetOutputsFormat(output_format);
      std::vector<TypeId> output_type;
      auto reduce_flag = kernel::NativeGpuKernelModFactory::GetInstance().reduce_flag_;
      if (std::find(reduce_flag.first.begin(), reduce_flag.first.end(), input_index) != reduce_flag.first.end()) {
        output_type = {reduce_flag.second};
      } else {
        auto iter = std::find_if(input_reduce_detail.begin(), input_reduce_detail.end(),
                                 [input_index](const std::tuple<size_t, TypeId, TypeId> &reduce_detail) {
                                   return std::get<0>(reduce_detail) == input_index;
                                 });
        if (iter != input_reduce_detail.end()) {
          output_type = {std::get<1>(*iter)};
        } else {
          output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
        }
      }
      builder->SetOutputsDeviceType(output_type);
      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
    }
  }
  kernel::NativeGpuKernelModFactory::GetInstance().reduce_flag_.first.clear();
}

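// Expands the special kAllPositions marker into the full position list [0, position_num).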
void TransformFormatPosition(std::vector<size_t> *format_position, size_t position_num) {
  MS_EXCEPTION_IF_NULL(format_position);
  if (format_position->size() == 0) {
    return;
  }

  // If the specified position is kAllPositions, insert all the positions.
  if ((*format_position)[0] == kAllPositions) {
    format_position->clear();
    for (size_t index = 0; index < position_num; index++) {
      format_position->push_back(index);
    }
  }
}

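// Checks whether the kernel node qualifies for the NCHW/NHWC format transform: the transform must be enabled, the op
// must be registered in kKernelFormatPositionMap, and every affected input and output must be 4-D.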
bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  if (!FormatTransformChecker::GetInstance().format_transform()) {
    return false;
  }
  if (!AnfUtils::IsRealCNodeKernel(kernel_node)) {
    return false;
  }
  auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
  auto iter = kKernelFormatPositionMap.find(kernel_name);
  if (iter == kKernelFormatPositionMap.end()) {
    return false;
  }
  if (inputs_type.size() == 0) {
    return false;
  }

  auto inputs_format_position = iter->second.first;
  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  TransformFormatPosition(&inputs_format_position, input_num);
  for (const auto &input_format_position : inputs_format_position) {
    auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_format_position);
    // Only the transform between NCHW and NHWC is supported, so the shape must be 4-dimensional.
    if (input_shape.size() != kFormatTransformDimension) {
      return false;
    }
  }

  auto outputs_format_position = iter->second.second;
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  TransformFormatPosition(&outputs_format_position, output_num);
  for (const auto &output_format_position : outputs_format_position) {
    auto output_shape = common::AnfAlgo::GetOutputInferShape(kernel_node, output_format_position);
    // Only the transform between NCHW and NHWC is supported, so the shape must be 4-dimensional.
    if (output_shape.size() != kFormatTransformDimension) {
      return false;
    }
  }
  return true;
}

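// Rewrites the input/output formats at the registered positions to NHWC (for float16) or NCHW (for other types) and
// records the op's original data format.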
void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type,
                            std::vector<std::string> *inputs_format, std::vector<std::string> *outputs_format,
                            std::string *origin_data_format) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(inputs_format);
  MS_EXCEPTION_IF_NULL(outputs_format);
  auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
  auto iter = kKernelFormatPositionMap.find(kernel_name);
  if (iter == kKernelFormatPositionMap.end()) {
    return;
  }
  auto cal_format = (inputs_type[0] == kNumberTypeFloat16) ? kOpFormat_NHWC : kOpFormat_NCHW;
  MS_LOG(DEBUG) << "Kernel node: " << kernel_node->fullname_with_scope() << ", format: " << cal_format;
  auto inputs_format_position = iter->second.first;
  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  TransformFormatPosition(&inputs_format_position, input_num);
  for (const auto &input_format_position : inputs_format_position) {
    if (input_format_position >= inputs_format->size()) {
      MS_LOG(EXCEPTION) << "The position [" << input_format_position << "] is out of range of the input size ["
                        << inputs_format->size() << "] #kernel_node [" << kernel_node->fullname_with_scope() << "]";
    }
    (*inputs_format)[input_format_position] = cal_format;
  }

  auto outputs_format_position = iter->second.second;
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  TransformFormatPosition(&outputs_format_position, output_num);
  for (const auto &output_format_position : outputs_format_position) {
    if (output_format_position >= outputs_format->size()) {
      MS_LOG(EXCEPTION) << "The position [" << output_format_position << "] is out of range of the output size ["
                        << outputs_format->size() << "] #kernel_node [" << kernel_node->fullname_with_scope() << "]";
    }
    (*outputs_format)[output_format_position] = cal_format;
  }
  auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->HasAttr("format")) {
    *origin_data_format = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, "format");
  }
}

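// Builds and sets the kernel build info of a graph kernel: selects kernel info for every inner node, then assembles
// the graph-level input/output formats and device types.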
void SetGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> node_list;
  std::vector<AnfNodePtr> input_list;
  std::vector<AnfNodePtr> output_list;
  kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);

  std::vector<std::string> graph_input_format;
  std::vector<TypeId> graph_input_type;
  // Set the graph kernel inputs' kernel info.
  for (size_t i = 0; i < input_list.size(); ++i) {
    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
    std::vector<std::string> outputs_format = {kOpFormat_DEFAULT};
    std::vector<TypeId> outputs_device_type = {common::AnfAlgo::GetOutputInferDataType(input_list[i], 0)};
    graph_input_format.push_back(kOpFormat_DEFAULT);
    graph_input_type.push_back(common::AnfAlgo::GetOutputInferDataType(input_list[i], 0));
    builder.SetOutputsFormat(outputs_format);
    builder.SetOutputsDeviceType(outputs_device_type);
    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
  }

  // Set the graph kernel inner nodes' kernel info.
  auto kernel_info_setter = GraphKernelInfoManager::Instance().GetGraphKernelInfo(kGPUDevice);
  MS_EXCEPTION_IF_NULL(kernel_info_setter);
  for (size_t i = 0; i < node_list.size(); ++i) {
    const auto &anf_node = node_list[i];
    MS_EXCEPTION_IF_NULL(anf_node);
    auto cnode = anf_node->cast<CNodePtr>();
    cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
    kernel_info_setter->SetKernelInfo(cnode, KernelType::AKG_KERNEL);
  }

  // Set the graph kernel node's kernel info.
  auto mng = func_graph->manager();
  if (mng == nullptr) {
    mng = Manage(func_graph, true);
    func_graph->set_manager(mng);
  }
  auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
  std::vector<std::string> graph_output_format;
  std::vector<TypeId> graph_output_type;
  for (size_t i = 0; i < output_index.size(); ++i) {
    auto const &output = output_index[i];
    graph_output_format.push_back(AnfAlgo::GetOutputFormat(output.first, output.second));
    graph_output_type.push_back(AnfAlgo::GetOutputDeviceDataType(output.first, output.second));
  }

  kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
  graph_info_builder.SetInputsFormat(graph_input_format);
  graph_info_builder.SetInputsDeviceType(graph_input_type);
  graph_info_builder.SetOutputsFormat(graph_output_format);
  graph_info_builder.SetOutputsDeviceType(graph_output_type);
  graph_info_builder.SetProcessor(kernel::Processor::CUDA);
  graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  graph_info_builder.SetFusionType(kPatternOpaque);
  auto graph_selected_info = graph_info_builder.Build();
  MS_EXCEPTION_IF_NULL(graph_selected_info);
  AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());
  SetTensorDeviceInfo(*graph_selected_info, kernel_node, {});
}

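// Composes the kernel-selection failure message: either the op is entirely unsupported on GPU, or the requested
// input/output data types are not in its supported list.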
std::pair<std::string, ExceptionType> PrintUnsupportedTypeWarning(const CNodePtr &kernel_node,
                                                                  const std::vector<TypeId> &inputs_type,
                                                                  const std::vector<TypeId> &outputs_type,
                                                                  KernelType kernel_type) {
  auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
  std::string build_type = "input[";
  std::for_each(std::begin(inputs_type), std::end(inputs_type),
                [&build_type](auto i) { build_type += TypeIdToString(i) + " "; });
  build_type += "] and output[";
  std::for_each(std::begin(outputs_type), std::end(outputs_type),
                [&build_type](auto i) { build_type += TypeIdToString(i) + " "; });
  build_type += "]";
  auto supported_type_lists = GetSupportedTypesStr(kernel_node, kernel_type);
  std::stringstream ss;
  ExceptionType etype;
  if (supported_type_lists.empty()) {
    ss << "Unsupported op [" << kernel_name << "] on GPU. Please confirm whether the device target setting is correct, "
       << "or refer to 'mindspore.ops' at https://www.mindspore.cn to query the operator support list."
       << trace::DumpSourceLines(kernel_node);
    etype = NotSupportError;
  } else {
    ss << "Selecting GPU operator[" << kernel_name << "] failed! Unsupported data type!\nThe supported data types are "
       << supported_type_lists << ", but got " << build_type << trace::DumpSourceLines(kernel_node);
    etype = TypeError;
  }
  return {ss.str(), etype};
}
}  // namespace

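// Decides whether the automatic NCHW<->NHWC format transform can be enabled for the graph. It stays disabled in
// PyNative mode, for dynamic shapes, on pre-Volta devices, for graphs that already use NHWC or contain LayerNorm,
// and for graphs with at most one Conv2D/BatchNorm.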
void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<bool>(MS_CTX_DISABLE_FORMAT_TRANSFORM)) {
    MS_LOG(INFO) << "Disable the automatic format transform function.";
    format_transform_ = false;
    return;
  }
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
    format_transform_ = false;
    return;
  }
  // Format transform causes the infer shape and the device shape to differ, so dynamic shape graphs can't be
  // supported.
  if (kernel_graph->is_dynamic_shape()) {
    MS_LOG(INFO) << "Dynamic shape doesn't support the automatic format transform function.";
    format_transform_ = false;
    return;
  }

  // TensorCore can be used only on Volta or newer devices.
  const int major_sm = GET_MAJOR_SM;
  if (major_sm < RECOMMEND_SM) {
    format_transform_ = false;
    return;
  }
  auto kernels = kernel_graph->execution_order();
  size_t conv_cnt = 0;
  size_t bn_cnt = 0;
  for (const auto &kernel : kernels) {
    auto kernel_name = common::AnfAlgo::GetCNodeName(kernel);
    if (kernel_name == prim::kPrimLayerNorm->name()) {
      format_transform_ = false;
      return;
    }
    auto value = common::AnfAlgo::GetCNodePrimitive(kernel);
    if (value != nullptr && value->GetAttr("format") != nullptr &&
        GetValue<std::string>(value->GetAttr("format")) == kOpFormat_NHWC) {
      format_transform_ = false;
      return;
    }
    if (kernel_name == prim::kPrimConv2D->name()) {
      conv_cnt++;
    }
    if (kernel_name == prim::kPrimBatchNorm->name()) {
      bn_cnt++;
    }
  }
  if (conv_cnt + bn_cnt > 1) {
    format_transform_ = true;
    return;
  }
  format_transform_ = false;
}

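// Tries to select a kernel that matches the build info in builder: Custom ops and AKG ops go through OpLib, while
// native GPU kernels are checked directly, with a reduce-precision fallback that rewrites the device types in
// builder.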
bool GetSelectKernelResult(const CNodePtr &kernel_node,
                           const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
                           KernelType *kernel_type,
                           std::vector<std::tuple<size_t, TypeId, TypeId>> *input_reduce_index) {
  MS_EXCEPTION_IF_NULL(builder);
  MS_EXCEPTION_IF_NULL(kernel_type);
  MS_EXCEPTION_IF_NULL(input_reduce_index);
  bool result = false;
  std::vector<std::tuple<size_t, TypeId, TypeId>> output_reduce_index;
  if (IsPrimitiveCNode(kernel_node, prim::kPrimCustom)) {
    // Custom ops select kernels from OpLib.
    result = SelectCustomKernel(kernel_node, builder->Build(), kernel_type);
  } else if (*kernel_type == UNKNOWN_KERNEL_TYPE) {
    auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
    if (kernel::Factory<kernel::NativeGpuKernelMod>::Instance().IsRegistered(kernel_name)) {
      result = kernel::NativeGpuKernelMod::GpuCheckSupport(kernel_name, GetKernelAttrFromBuildInfo(builder->Build()));
      if (!result) {
        std::tie(result, *input_reduce_index, output_reduce_index) =
          kernel::NativeGpuKernelMod::GpuReducePrecisionCheck(kernel_name,
                                                              GetKernelAttrFromBuildInfo(builder->Build()));
        if (result) {
          const size_t kReduceToTypeIdx = 2;
          for (const auto &item : *input_reduce_index) {
            auto idx = std::get<0>(item);
            auto to_type_id = std::get<kReduceToTypeIdx>(item);
            builder->SetInputDeviceType(to_type_id, idx);
          }
          for (const auto &item : output_reduce_index) {
            auto idx = std::get<0>(item);
            auto to_type_id = std::get<kReduceToTypeIdx>(item);
            builder->SetOutputDeviceType(to_type_id, idx);
          }
        }
      }
    } else {
      result = kernel::NativeGpuKernelModFactory::GetInstance().SearchRegistered(
        common::AnfAlgo::GetCNodeName(kernel_node), builder->Build());
      if (!result) {
        result = kernel::NativeGpuKernelModFactory::GetInstance().ReducePrecision(
          common::AnfAlgo::GetCNodeName(kernel_node), builder);
      }
    }

    if (!result && (!common::AnfAlgo::IsBpropCutOpExecInBackend(kernel_node))) {
      result = SelectAkgKernel(kernel_node, builder->Build());
      *kernel_type = AKG_KERNEL;
    }
  } else if (*kernel_type == AKG_KERNEL) {
    result = SelectAkgKernel(kernel_node, builder->Build());
  }
  return result;
}

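// Selects the kernel object types (tensor vs. tuple handling) for the node; returns {false, warning} when no
// registered kernel attr matches the node's object types.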
std::pair<bool, std::pair<std::string, ExceptionType>> GetSelectKernelObjectTypeResult(const CNodePtr &kernel_node,
                                                                                       KernelType kernel_type) {
  auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
  // Only kernel nodes that register kernel attrs can support the backoff.
  bool backoff_support_condition =
    ((kernel_type == UNKNOWN_KERNEL_TYPE) && !IsPrimitiveCNode(kernel_node, prim::kPrimCustom) &&
     !common::AnfAlgo::IsGraphKernel(kernel_node));
  std::vector<kernel::KernelAttr> kernel_attrs;
  if (kernel::NativeGpuKernelModFactory::GetInstance().IsRegistered(kernel_name)) {
    kernel_attrs = kernel::NativeGpuKernelModFactory::GetInstance().GetGpuSupportedList(kernel_name);
  } else if (backoff_support_condition) {
    // A kernel that is not supported on GPU can try to back off to CPU and use the CPU kernel attrs to set the
    // object type.
    kernel_attrs = kernel::NativeCpuKernelMod::GetCpuSupportedList(kernel_name);
  }

  // Some dynamic kernels may not set the kernel attrs on GPU. The skip-check path only supports tuple fold when
  // KeepTuple is not applied to the target.
  if (kernel_attrs.empty() || kernel_attrs[0].GetSkipCheck()) {
    std::vector<kernel::KernelObjectType> input_object_types;
    std::vector<kernel::KernelObjectType> output_object_types;
    std::vector<kernel::KernelObjectType> output_element_object_types;
    if (!kernel_attrs.empty() && common::AnfAlgo::HasNodeAttr(kInputRealTuple, kernel_node)) {
      input_object_types = kernel::TypeIdToKernelObjectType(AnfAlgo::GetAllInputObjectType(kernel_node));
    } else {
      input_object_types = kernel::TypeIdToKernelObjectTypeForTupleUnfold(AnfAlgo::GetAllInputObjectType(kernel_node));
    }

    if (!kernel_attrs.empty() && common::AnfAlgo::HasNodeAttr(kOutputRealTuple, kernel_node)) {
      output_object_types = kernel::TypeIdToKernelObjectType(AnfAlgo::GetAllOutputObjectType(kernel_node));
    } else {
      output_object_types =
        kernel::TypeIdToKernelObjectTypeForTupleUnfold(AnfAlgo::GetAllOutputObjectType(kernel_node));
    }

    if ((!kernel_attrs.empty()) && kernel_attrs[0].GetSkipCheck() && output_object_types.size() == 1 &&
        output_object_types[0] == kernel::KernelObjectType::TUPLE_UNFOLD) {
      size_t output_num = kernel::GetOutputNum(kernel_node);
      for (size_t i = 0; i < output_num; ++i) {
        auto object_type_ptr = common::AnfAlgo::GetOutputInferType(kernel_node, i);
        MS_EXCEPTION_IF_NULL(object_type_ptr);
        output_element_object_types.emplace_back(kernel::TypeIdToKernelObjectType(object_type_ptr->type_id()));
      }
    }

    MS_LOG(DEBUG) << "Set kernel object type build info for node:" << kernel_node->DebugString()
                  << " output type:" << output_object_types << " output element type:" << output_element_object_types;
    kernel::SetKernelObjectTypeBuildInfo(kernel_node, input_object_types, output_object_types,
                                         output_element_object_types);
    if (!kernel_attrs.empty()) {
      auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
      kernel_build_info->SetOpType(kernel::OpType::SKIP);
    }
    return {true, {}};
  }

  std::vector<kernel::KernelAttr> object_selected_kernel_attrs;
  if (!kernel::SelectKernelByObjectType(kernel_node, kernel_attrs, &object_selected_kernel_attrs)) {
    return {false, kernel::KernelObjectTypeNotSupportWarning(kernel_node)};
  }

  kernel::SetKernelObjectTypeWithSelectedAttr(kernel_node, object_selected_kernel_attrs[0]);
  return {true, {}};
}

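// The main kernel-selection entry for a single node: builds the kernel build info (formats, device types, object
// types) and returns a non-empty message plus an exception type on failure.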
std::pair<std::string, ExceptionType> SetKernelInfoWithMsg(const CNodePtr &kernel_node, KernelType kernel_type) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  if (common::AnfAlgo::IsGraphKernel(kernel_node)) {
    auto func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(kernel_node);
    MS_EXCEPTION_IF_NULL(func_graph);
    SetGraphKernelInfo(kernel_node, func_graph);
    return {};
  }
  auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
  auto selected = GetSelectKernelObjectTypeResult(kernel_node, kernel_type);
  if (!selected.first) {
    return selected.second;
  }
  std::vector<std::string> inputs_format;
  std::vector<TypeId> inputs_type;
  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  for (size_t input_index = 0; input_index < input_num; ++input_index) {
    (void)inputs_format.emplace_back(kOpFormat_DEFAULT);
    inputs_type.push_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
  }

  std::vector<std::string> outputs_format;
  std::vector<TypeId> outputs_type;
  auto output_kernel_object_types = builder->Build()->GetAllOutputKernelObjectTypes();
  if (output_kernel_object_types.size() == 1 && output_kernel_object_types[0] == kernel::KernelObjectType::TUPLE) {
    outputs_type = {common::AnfAlgo::GetOutputInferDataType(kernel_node, 0)};
    outputs_format = {kOpFormat_DEFAULT};
  } else {
    size_t output_num = kernel::GetOutputNum(kernel_node);
    for (size_t output_index = 0; output_index < output_num; ++output_index) {
      (void)outputs_format.emplace_back(kOpFormat_DEFAULT);
      outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
    }
  }
  std::string origin_data_format = kOpFormat_DEFAULT;
  if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
    UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format);
  }
  builder->SetOriginDataFormat(origin_data_format);
  builder->SetInputsFormat(inputs_format);
  builder->SetInputsDeviceType(inputs_type);
  builder->SetOutputsFormat(outputs_format);
  builder->SetOutputsDeviceType(outputs_type);
  kernel::UnfoldKernelBuildInfo(kernel_node);
  if (!common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, kernel_node)) {
    kernel::SetDynamicInputSizeAttr(kernel_node);
  }
  MS_LOG(INFO) << kernel_node->fullname_with_scope() << " kernel attr info: "
               << kernel::FetchPrintInfoByKernelAttr(kernel::GetKernelAttrFromBuildInfo(builder->Build()));

  std::vector<std::tuple<size_t, TypeId, TypeId>> input_reduce_index;
  bool result = GetSelectKernelResult(kernel_node, builder, &kernel_type, &input_reduce_index);
  SetTensorDeviceInfo(*(builder->Build()), kernel_node, input_reduce_index);

  // Return the kernel selection failure info.
  if (common::AnfAlgo::HasNodeAttr(ops::kBatchRank, kernel_node) &&
      !kVmapGPUWhiteList.count(common::AnfAlgo::GetCNodeName(kernel_node))) {
    builder->SetKernelType(UNKNOWN_KERNEL_TYPE);
    builder->SetProcessor(kernel::Processor::UNKNOWN);
    std::stringstream ss;
    ss << common::AnfAlgo::GetCNodeName(kernel_node)
       << " does not support 'batch_rank' on GPU, which means that 'vmap' cannot support "
       << common::AnfAlgo::GetCNodeName(kernel_node) << " on GPU currently.";
    return {ss.str(), NotSupportError};
  }
  if (!result && (!common::AnfAlgo::IsBpropCutOpExecInBackend(kernel_node))) {
    builder->SetKernelType(UNKNOWN_KERNEL_TYPE);
    builder->SetProcessor(kernel::Processor::UNKNOWN);
    return PrintUnsupportedTypeWarning(kernel_node, inputs_type, outputs_type, kernel_type);
  }

  builder->SetKernelType(kernel_type);
  builder->SetProcessor(kernel::Processor::CUDA);
  return {};
}

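// Sets the kernel info for the node and raises the corresponding exception when kernel selection fails.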
void GPUGraphKernelInfo::SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
  auto [msg, etype] = SetKernelInfoWithMsg(kernel_node, kernel_type);
  if (msg.empty()) {
    return;
  }
  MS_EXCEPTION(etype) << "#umsg#Kernel select failed:#umsg#" << msg;
}
}  // namespace gpu
}  // namespace device
}  // namespace mindspore