• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/kernel_compiler/kernel_query.h"
18 #include "backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h"
19 #include "backend/kernel_compiler/host/host_kernel_metadata.h"
20 #include "backend/kernel_compiler/rts/rt_kernel_info.h"
21 #include "backend/kernel_compiler/hccl/hccl_kernel_metadata.h"
22 #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h"
23 #include "backend/kernel_compiler/akg/akg_kernel_metadata.h"
24 #include "backend/session/anf_runtime_algorithm.h"
25 #include "utils/ms_context.h"
26 #include "utils/trace_base.h"
27 
28 namespace mindspore {
29 namespace kernel {
30 namespace {
FilterInvalidKernelInfo(const CNodePtr & kernel_node,std::vector<std::shared_ptr<kernel::KernelBuildInfo>> * kernel_info_list)31 void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
32                              std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
33   MS_EXCEPTION_IF_NULL(kernel_info_list);
34   if (kernel_info_list->empty()) {
35     return;
36   }
37   MS_EXCEPTION_IF_NULL(kernel_node);
38   size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(kernel_node);
39   size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel_node);
40   std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
41   (void)std::copy_if(
42     kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
43     [output_tensor_num, input_tensor_num](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
44       return kernel_build_info->GetOutputNum() == output_tensor_num &&
45              kernel_build_info->GetInputNum() == input_tensor_num;
46     });
47   if (!filtered_list.empty()) {
48     kernel_info_list->clear();
49     (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list));
50   } else {
51     for (size_t index = 0; index < kernel_info_list->size(); ++index) {
52       std::ostringstream buffer;
53       auto &kernel_info = kernel_info_list->at(index);
54       MS_EXCEPTION_IF_NULL(kernel_info);
55       if (kernel_info->GetOutputNum() != output_tensor_num) {
56         buffer << "Kernel node's output size [" << output_tensor_num << "]"
57                << " cannot match the kernel's output size [" << kernel_info->GetOutputNum() << "]";
58       } else {
59         buffer << "Kernel node's input size [" << input_tensor_num << "]"
60                << " cannot match the kernel's input size [" << kernel_info->GetInputNum() << "]";
61       }
62       MS_LOG(INFO) << "Kernel [ " << index << " ] :" << kernel_info->ToString() << buffer.str();
63     }
64     kernel_info_list->clear();
65     MS_LOG(INFO) << "Node: " << kernel_node->DebugString() << "'s output size : [" << output_tensor_num << "]"
66                  << "input size : [" << input_tensor_num << "] can not match any kernelInfo !";
67   }
68 }
69 }  // namespace
70 
CheckKernelInfoListEmpty(const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> * kernel_info_list,const std::string & type)71 void CheckKernelInfoListEmpty(const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
72                               const std::string &type) {
73   MS_EXCEPTION_IF_NULL(kernel_info_list);
74   if (kernel_info_list->empty()) {
75     MS_LOG(INFO) << "Warning: kernel info list is empty, kernel type: " << type;
76   }
77 }
78 
KernelQueryAll(const CNodePtr & kernel_node,std::vector<std::shared_ptr<kernel::KernelBuildInfo>> * kernel_info_list)79 void KernelQueryAll(const CNodePtr &kernel_node,
80                     std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
81   MS_EXCEPTION_IF_NULL(kernel_node);
82   MS_EXCEPTION_IF_NULL(kernel_info_list);
83   TbeMetadataInfo(kernel_node, kernel_info_list);
84   if (kernel_info_list->empty()) {
85     GetRtKelInfo(kernel_node, kernel_info_list);
86     CheckKernelInfoListEmpty(kernel_info_list, "RT_Kernel");
87   }
88   if (kernel_info_list->empty()) {
89     HcclMetadataInfo(kernel_node, kernel_info_list);
90     CheckKernelInfoListEmpty(kernel_info_list, "HCCL_Kernel");
91   }
92   if (kernel_info_list->empty()) {
93     HostMetadataInfo(kernel_node, kernel_info_list);
94     CheckKernelInfoListEmpty(kernel_info_list, "HOST_Kernel");
95   }
96 }
97 
KernelQuery(const CNodePtr & kernel_node,std::vector<std::shared_ptr<kernel::KernelBuildInfo>> * kernel_info_list,KernelType kernel_type)98 void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
99                  KernelType kernel_type) {
100   MS_EXCEPTION_IF_NULL(kernel_node);
101   MS_EXCEPTION_IF_NULL(kernel_info_list);
102 
103   auto context_ptr = MsContext::GetInstance();
104   MS_EXCEPTION_IF_NULL(context_ptr);
105 
106   const PrimitivePtr kPrimProdForceSeA = std::make_shared<Primitive>("ProdForceSeA");
107   if (IsPrimitiveCNode(kernel_node, kPrimProdForceSeA)) {
108     kernel_type = KernelType::AKG_KERNEL;
109   }
110 
111   const PrimitivePtr kPrimLoadIm2Col = std::make_shared<Primitive>("LoadIm2Col");
112   if (IsPrimitiveCNode(kernel_node, kPrimLoadIm2Col)) {
113     kernel_type = KernelType::AKG_KERNEL;
114   }  // use LoadIm2Col only for THOR optimizer
115 
116   switch (kernel_type) {
117     case KernelType::AKG_KERNEL:
118       AkgMetadataInfo(kernel_node, kernel_info_list);
119       break;
120     default:
121       KernelQueryAll(kernel_node, kernel_info_list);
122       break;
123   }
124   // check output
125   FilterInvalidKernelInfo(kernel_node, kernel_info_list);
126 }
127 
AICPUQuery(const CNodePtr & kernel_node,std::vector<std::shared_ptr<kernel::KernelBuildInfo>> * kernel_info_list)128 void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
129   MS_EXCEPTION_IF_NULL(kernel_node);
130   MS_EXCEPTION_IF_NULL(kernel_info_list);
131   kernel_info_list->clear();
132   AicpuMetadataInfo(kernel_node, kernel_info_list);
133   FilterInvalidKernelInfo(kernel_node, kernel_info_list);
134 }
135 
IsSupportedByAICPU(const AnfNodePtr & kernel_node,const KernelBuildInfoPtr & select_kernel_build_info)136 bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
137   MS_EXCEPTION_IF_NULL(kernel_node);
138   MS_EXCEPTION_IF_NULL(select_kernel_build_info);
139   std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
140   auto cnode = kernel_node->cast<CNodePtr>();
141   MS_EXCEPTION_IF_NULL(cnode);
142   AICPUQuery(cnode, &kernel_info_list);
143   return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
144                      [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
145                        MS_EXCEPTION_IF_NULL(item);
146                        return item->IsSimilarityKernelBuildInfo(*select_kernel_build_info);
147                      });
148 }
149 
IsSupportedByAICore(const AnfNodePtr & kernel_node,const KernelBuildInfoPtr & select_kernel_build_info)150 bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
151   MS_EXCEPTION_IF_NULL(kernel_node);
152   MS_EXCEPTION_IF_NULL(select_kernel_build_info);
153   std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
154   auto cnode = kernel_node->cast<CNodePtr>();
155   MS_EXCEPTION_IF_NULL(cnode);
156   TbeMetadataInfo(cnode, &kernel_info_list);
157   return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
158                      [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
159                        MS_EXCEPTION_IF_NULL(item);
160                        return *item == *select_kernel_build_info;
161                      });
162 }
163 }  // namespace kernel
164 }  // namespace mindspore
165