• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
18 
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
22 
23 #include "runtime/device/kernel_info.h"
24 #include "runtime/device/cpu/kernel_select_cpu.h"
25 
26 namespace mindspore {
27 namespace kernel {
namespace {
// Operator names for which SetKernelAttrs calls SetAllSameAttr(true) on every generated KernelAttr.
// For such attrs CPUKernelSingleAttrCheck compares every input/output dtype against entry 0.
const std::set<std::string> same_op_name{
  "AddN", "Concat", "Pack", "Split", "Stack", "Transpose", "Unpack",
};
}  // namespace
31 
GetInstance()32 CPUKernelFactory &CPUKernelFactory::GetInstance() {
33   static CPUKernelFactory instance;
34   return instance;
35 }
36 
Register(const std::string & kernel_name,const KernelAttr & kernel_attr,CPUKernelCreator && kernel_creator)37 void CPUKernelFactory::Register(const std::string &kernel_name, const KernelAttr &kernel_attr,
38                                 CPUKernelCreator &&kernel_creator) {
39   (void)name_to_attr_creator_[kernel_name].emplace_back(kernel_attr, kernel_creator);
40 #if !defined(_WIN32) && !defined(_WIN64)
41   MS_LOG(DEBUG) << "CPUKernelFactory register operator: " << kernel_name;
42 #endif
43 }
44 
Create(const std::string & kernel_name,const CNodePtr & apply_kernel)45 std::shared_ptr<CPUKernel> CPUKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) {
46   MS_EXCEPTION_IF_NULL(apply_kernel);
47   auto kernel_info = dynamic_cast<device::KernelInfo *>(apply_kernel->kernel_info());
48   MS_EXCEPTION_IF_NULL(kernel_info);
49   const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info();
50   MS_EXCEPTION_IF_NULL(kernel_build_Info);
51   std::pair<bool, size_t> ret_pair = CPUKernelAttrCheck(kernel_name, *kernel_build_Info);
52   if (ret_pair.first) {
53     return (name_to_attr_creator_.find(kernel_name)->second)[ret_pair.second].second();
54   }
55   return nullptr;
56 }
57 
SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_info,std::vector<KernelAttr> * kernel_attrs)58 void CPUKernelFactory::SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_info,
59                                       std::vector<KernelAttr> *kernel_attrs) {
60   MS_EXCEPTION_IF_NULL(kernel_attrs);
61   MS_EXCEPTION_IF_NULL(op_info);
62   auto inputs_ptr = op_info->inputs_ptr();
63   auto outputs_ptr = op_info->outputs_ptr();
64   if (inputs_ptr.empty()) {
65     MS_LOG(EXCEPTION) << "op " << op_info->op_name() << " input size is zero.";
66   }
67   auto first_input_dtypes = inputs_ptr[0]->dtypes();
68   auto input_formats = inputs_ptr[0]->formats();
69 
70   for (size_t i = 0; i < first_input_dtypes.size(); i++) {
71     KernelAttr kernel_attr;
72     (void)kernel_attr.AddInputAttr(kernel::DtypeToTypeId(first_input_dtypes[i]), input_formats[i]);
73     for (size_t j = 1; j < inputs_ptr.size(); j++) {
74       auto input_dtypes = inputs_ptr[j]->dtypes();
75       input_formats = inputs_ptr[j]->formats();
76       (void)kernel_attr.AddInputAttr(kernel::DtypeToTypeId(input_dtypes[i]), input_formats[i]);
77     }
78     for (size_t j = 0; j < outputs_ptr.size(); j++) {
79       auto output_dtypes = outputs_ptr[j]->dtypes();
80       auto output_formats = outputs_ptr[j]->formats();
81       (void)kernel_attr.AddOutputAttr(kernel::DtypeToTypeId(output_dtypes[i]), output_formats[i]);
82     }
83     if (same_op_name.count(op_info->op_name()) != 0) {
84       (void)kernel_attr.SetAllSameAttr(true);
85     }
86     (void)kernel_attrs->emplace_back(kernel_attr);
87   }
88 }
89 
UpdateKernelAttrs(const std::string & kernel_name,const std::vector<KernelAttr> & kernel_attrs)90 void CPUKernelFactory::UpdateKernelAttrs(const std::string &kernel_name, const std::vector<KernelAttr> &kernel_attrs) {
91   size_t attr_size = kernel_attrs.size();
92   std::vector<std::pair<KernelAttr, CPUKernelCreator>> attr_creators(attr_size);
93   auto iter = name_to_attr_creator_.find(kernel_name);
94   if (iter == name_to_attr_creator_.end()) {
95     MS_LOG(EXCEPTION) << "CPUKernelFactory has not registered operator: " << kernel_name;
96   }
97 
98   if (attr_size <= iter->second.size()) {
99     for (size_t i = 0; i < attr_size; i++) {
100       auto creator = name_to_attr_creator_.find(kernel_name)->second[i].second;
101       attr_creators[i] = std::make_pair(kernel_attrs[i], creator);
102     }
103   } else {
104     MS_LOG(INFO) << "attr size is not equal creators size " << kernel_name << " attr_size = " << attr_size
105                  << " creator_size = " << iter->second.size();
106     auto single_creator = name_to_attr_creator_.find(kernel_name)->second[0].second;
107     for (size_t i = 0; i < attr_size; i++) {
108       attr_creators[i] = std::make_pair(kernel_attrs[i], single_creator);
109     }
110   }
111   name_to_attr_creator_[kernel_name] = attr_creators;
112 }
113 
CPUKernelAttrCheck(const std::string & kernel_name,const KernelBuildInfo & kernel_info)114 std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &kernel_name,
115                                                              const KernelBuildInfo &kernel_info) {
116   auto iter = name_to_attr_creator_.find(kernel_name);
117   if (iter == name_to_attr_creator_.end()) {
118     MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!";
119     return std::make_pair(false, 0);
120   }
121 
122   if (device::cpu::IsDynamicParamKernel(kernel_name)) {
123     return std::make_pair(true, 0);
124   }
125 
126   auto kernel_attrs = GetSupportedKernelAttrList(kernel_name);
127   if (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) {
128     auto op_info_ptr = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU);
129     if (op_info_ptr == nullptr) {
130       MS_LOG(EXCEPTION) << "Not find op[" << kernel_name << "] in cpu";
131     }
132     kernel_attrs.clear();
133     SetKernelAttrs(op_info_ptr, &kernel_attrs);
134     kernel::CPUKernelFactory::GetInstance().UpdateKernelAttrs(kernel_name, kernel_attrs);
135   }
136   for (size_t index = 0; index < kernel_attrs.size(); ++index) {
137     if (CPUKernelSingleAttrCheck(kernel_attrs[index], kernel_info)) {
138       return std::make_pair(true, index);
139     }
140   }
141   return std::make_pair(false, 0);
142 }
143 
CPUKernelSingleAttrCheck(const KernelAttr & kernel_attr,const KernelBuildInfo & kernel_info) const144 bool CPUKernelFactory::CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr,
145                                                 const KernelBuildInfo &kernel_info) const {
146   for (size_t i = 0; i < kernel_info.GetInputNum(); ++i) {
147     auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetInputAttr(0).first : kernel_attr.GetInputAttr(i).first;
148     if (kernel_info.GetInputDeviceType(i) != dtype) {
149       MS_LOG(DEBUG) << "input index:" << i << ", kernel info type:" << kernel_info.GetInputDeviceType(i)
150                     << ", register type:" << dtype;
151       return false;
152     }
153   }
154   for (size_t i = 0; i < kernel_info.GetOutputNum(); ++i) {
155     auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetOutputAttr(0).first : kernel_attr.GetOutputAttr(i).first;
156     if (kernel_info.GetOutputDeviceType(i) != dtype) {
157       MS_LOG(DEBUG) << "output index:" << i << ", kernel info type:" << kernel_info.GetOutputDeviceType(i)
158                     << ", register type:" << dtype;
159       return false;
160     }
161   }
162   return true;
163 }
164 
GetSupportedKernelAttrList(const std::string & kernel_name)165 std::vector<KernelAttr> CPUKernelFactory::GetSupportedKernelAttrList(const std::string &kernel_name) {
166   std::vector<KernelAttr> result;
167   auto iter = name_to_attr_creator_.find(kernel_name);
168   if (iter == name_to_attr_creator_.end()) {
169     MS_LOG(EXCEPTION) << "Not registered CPU kernel: op[" << kernel_name << "]!";
170   }
171   auto creators = iter->second;
172   result.reserve(creators.size());
173   for (size_t index = 0; index < creators.size(); ++index) {
174     auto attr_creator = creators[index];
175     (void)result.emplace_back(attr_creator.first);
176   }
177   return result;
178 }
179 }  // namespace kernel
180 }  // namespace mindspore
181