• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/litert/kernel_registry.h"
17 #include <utility>
18 #include <memory>
19 #include "include/errorcode.h"
20 #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
21 #include "include/registry/register_kernel.h"
22 #endif
23 #include "src/common/ops/populate/populate_register.h"
24 #include "src/common/version_manager.h"
25 #include "nnacl/pooling_parameter.h"
26 #if defined(ENABLE_FP16) && defined(ENABLE_ARM)
27 #if defined(__ANDROID__)
28 #include <asm/hwcap.h>
29 #endif
30 #include "common/utils.h"
31 #include "src/common/log_adapter.h"
32 #include "src/common/utils.h"
33 #endif
34 #include "src/common/tensor_util.h"
35 #include "src/litert/kernel/cpu/nnacl/nnacl_manager.h"
36 #ifdef ENABLE_BOLT
37 #include "src/litert/kernel/cpu/bolt/bolt_kernel_manager.h"
38 #endif
39 
40 using mindspore::kernel::kBuiltin;
41 using mindspore::kernel::kCPU;
42 using mindspore::kernel::KERNEL_ARCH;
43 using mindspore::kernel::KernelKey;
44 
45 namespace mindspore::lite {
CreatorArraysInit()46 void KernelRegistry::CreatorArraysInit() {
47   std::unique_lock<std::mutex> malloc_creator_array(lock_);
48   if (creator_arrays_ == nullptr) {
49     creator_arrays_ = reinterpret_cast<kernel::KernelCreator *>(malloc(array_size_ * sizeof(kernel::KernelCreator)));
50     if (creator_arrays_ != nullptr) {
51       memset(creator_arrays_, 0, array_size_ * sizeof(kernel::KernelCreator));
52     }
53   }
54   if (inner_op_creator_arrays_ == nullptr) {
55     inner_op_creator_arrays_ =
56       reinterpret_cast<kernel::KernelCreator *>(malloc(inner_op_array_size_ * sizeof(kernel::KernelCreator)));
57     if (inner_op_creator_arrays_ != nullptr) {
58       memset(inner_op_creator_arrays_, 0, inner_op_array_size_ * sizeof(kernel::KernelCreator));
59     }
60   }
61   return;
62 }
63 
GetInstance()64 KernelRegistry *KernelRegistry::GetInstance() {
65   static KernelRegistry instance;
66   return &instance;
67 }
68 
GetCreator(const KernelKey & desc)69 kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
70   if (desc.format != NHWC) {
71     /* nchw kernel using nnacl kernel */
72     return nullptr;
73   }
74 
75   if (desc.provider == kBuiltin) {
76     int index = GetCreatorFuncIndex(desc);
77     if (desc.type >= PrimType_MIN && desc.type < PrimType_MAX) {
78       if (index >= array_size_ || index < 0) {
79         MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type " << desc.data_type << ",op type "
80                       << desc.type;
81         return nullptr;
82       }
83       if (creator_arrays_ != nullptr) {
84         return creator_arrays_[index];
85       }
86     } else if (desc.type >= PrimType_InnerOpMin && desc.type < PrimType_InnerOpMax) {
87       MS_CHECK_TRUE_RET(index >= 0 && index < inner_op_array_size_, nullptr);
88       if (inner_op_creator_arrays_ != nullptr) {
89         return inner_op_creator_arrays_[index];
90       }
91     }
92   }
93   MS_LOG(ERROR) << "Call wrong interface!provider: " << desc.provider;
94   return nullptr;
95 }
96 
GetCreatorFuncIndex(const kernel::KernelKey desc)97 int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) {
98   int device_index = static_cast<int>(desc.arch) - kKernelArch_MIN;
99   int dType_index = desc.data_type == kObjectTypeString ? 0 : static_cast<int>(desc.data_type) - kNumberTypeBegin;
100   int op_index = static_cast<int>(desc.type);
101   int op_type_length = op_type_length_;
102   if (op_index >= PrimType_InnerOpMin && desc.type < PrimType_InnerOpMax) {
103     op_type_length = inner_op_type_length_;
104     op_index -= PrimType_InnerOpMin;
105   }
106   int index = device_index * data_type_length_ * op_type_length + dType_index * op_type_length + op_index;
107   return index;
108 }
109 
RegKernel(const KernelKey desc,const kernel::KernelCreator creator)110 void KernelRegistry::RegKernel(const KernelKey desc, const kernel::KernelCreator creator) {
111   CreatorArraysInit();
112   int index = GetCreatorFuncIndex(desc);
113   if (desc.type >= PrimType_MIN && desc.type < PrimType_MAX) {
114     if (index >= array_size_ || index < 0) {
115       MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type "
116                     << desc.type;
117       return;
118     }
119     if (creator_arrays_ != nullptr) {
120       creator_arrays_[index] = creator;
121     }
122   } else if (desc.type >= PrimType_InnerOpMin && desc.type < PrimType_InnerOpMax) {
123     MS_CHECK_TRUE_RET_VOID(index >= 0 && index < inner_op_array_size_);
124     if (inner_op_creator_arrays_ != nullptr) {
125       inner_op_creator_arrays_[index] = creator;
126     }
127   }
128 }
129 
RegKernel(KERNEL_ARCH arch,TypeId data_type,int op_type,kernel::KernelCreator creator)130 void KernelRegistry::RegKernel(KERNEL_ARCH arch, TypeId data_type, int op_type, kernel::KernelCreator creator) {
131   CreatorArraysInit();
132   KernelKey desc = {arch, data_type, NHWC, op_type};
133   int index = GetCreatorFuncIndex(desc);
134   if (desc.type >= PrimType_MIN && desc.type < PrimType_MAX) {
135     if (index >= array_size_ || index < 0) {
136       MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type "
137                     << desc.type;
138       return;
139     }
140     if (creator_arrays_ != nullptr) {
141       creator_arrays_[index] = creator;
142     }
143   } else if (desc.type >= PrimType_InnerOpMin && desc.type < PrimType_InnerOpMax) {
144     MS_CHECK_TRUE_RET_VOID(index >= 0 && index < inner_op_array_size_);
145     if (inner_op_creator_arrays_ != nullptr) {
146       inner_op_creator_arrays_[index] = creator;
147     }
148   }
149 }
150 
~KernelRegistry()151 KernelRegistry::~KernelRegistry() {
152   KernelRegistry *instance = GetInstance();
153   std::unique_lock<std::mutex> malloc_creator_array(instance->lock_);
154   if (instance->creator_arrays_ != nullptr) {
155     free(instance->creator_arrays_);
156     instance->creator_arrays_ = nullptr;
157   }
158   if (instance->inner_op_creator_arrays_ != nullptr) {
159     free(instance->inner_op_creator_arrays_);
160     instance->inner_op_creator_arrays_ = nullptr;
161   }
162 }
163 
SupportKernel(const KernelKey & key)164 bool KernelRegistry::SupportKernel(const KernelKey &key) {
165   auto kernel_creator = GetCreator(key);
166   if (kernel_creator != nullptr) {
167     return true;
168   }
169   return nnacl::NNACLSupportKernel(key.type, key.data_type);
170 }
171 
GetCustomKernel(const std::vector<Tensor * > & in_tensors,const std::vector<Tensor * > & out_tensors,const mindspore::Context * ms_ctx,const kernel::KernelKey & key,kernel::KernelExec ** kernel,const void * primitive)172 int KernelRegistry::GetCustomKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
173                                     const mindspore::Context *ms_ctx, const kernel::KernelKey &key,
174                                     kernel::KernelExec **kernel, const void *primitive) {
175 #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
176   MS_ASSERT(ms_ctx != nullptr);
177   MS_ASSERT(kernel != nullptr);
178   registry::KernelDesc desc{static_cast<DataType>(key.data_type), key.type, key.kernel_arch, key.provider};
179   auto creator = registry::RegisterKernel::GetCreator(static_cast<const schema::Primitive *>(primitive), &desc);
180   if (creator == nullptr) {
181     return RET_NOT_SUPPORT;
182   }
183 
184   auto base_kernel = creator(LiteTensorsToMSTensors(in_tensors), LiteTensorsToMSTensors(out_tensors),
185                              static_cast<const schema::Primitive *>(primitive), ms_ctx);
186   if (base_kernel != nullptr) {
187     auto *kernel_exec = new (std::nothrow) kernel::KernelExec(base_kernel);
188     if (kernel_exec != nullptr) {
189       constexpr auto kArchCPU = "CPU";
190       constexpr auto kArchGPU = "GPU";
191       kernel::KernelKey tmp_key = key;
192       if (desc.arch == kArchCPU) {
193         tmp_key.arch = kernel::kCPU;
194       } else if (desc.arch == kArchGPU) {
195         tmp_key.arch = kernel::kGPU;
196       } else {
197         tmp_key.arch = kernel::kCustom;
198       }
199       kernel_exec->set_desc(tmp_key);
200       *kernel = kernel_exec;
201       return RET_OK;
202     }
203   }
204 #endif
205   return RET_ERROR;
206 }
207 
GetLiteKernel(const std::vector<Tensor * > & in_tensors,const std::vector<Tensor * > & out_tensors,const InnerContext * ctx,kernel::KernelKey * key,OpParameter * parameter)208 kernel::LiteKernel *KernelRegistry::GetLiteKernel(const std::vector<Tensor *> &in_tensors,
209                                                   const std::vector<Tensor *> &out_tensors, const InnerContext *ctx,
210                                                   kernel::KernelKey *key, OpParameter *parameter) {
211 #ifdef ENABLE_BOLT
212   if (key->arch == KERNEL_ARCH::kCPU) {
213     auto *bolt_kernel = kernel::bolt::BoltKernelRegistry(parameter, in_tensors, out_tensors, ctx, key);
214     if (bolt_kernel == nullptr) {
215       MS_LOG(DEBUG) << "Registry bolt kernel failed: " << parameter->name_;
216     } else {
217       bolt_kernel->set_registry_data_type(key->data_type);
218       return bolt_kernel;
219     }
220   }
221 #endif
222 
223   auto creator = GetCreator(*key);
224   if (creator != nullptr) {
225     auto lite_kernel = creator(in_tensors, out_tensors, parameter, ctx, *key);
226     if (lite_kernel != nullptr) {
227       lite_kernel->set_registry_data_type(key->data_type);
228       return lite_kernel;
229     }
230     return nullptr;
231   }
232   if (key->arch != KERNEL_ARCH::kCPU) {
233     return nullptr;
234   }
235 
236   auto *lite_kernel = nnacl::NNACLKernelRegistry(parameter, in_tensors, out_tensors, ctx, *key);
237   if (lite_kernel == nullptr) {
238     MS_LOG(WARNING) << "Registry cpu kernel failed:  " << parameter->name_;
239     return nullptr;
240   }
241   lite_kernel->set_registry_data_type(key->data_type);
242   return lite_kernel;
243 }
244 
GetKernelExec(const std::vector<Tensor * > & in_tensors,const std::vector<Tensor * > & out_tensors,const InnerContext * ctx,const mindspore::Context * ms_ctx,const kernel::KernelKey & key,OpParameter * parameter,kernel::KernelExec ** kernel,const void * primitive)245 int KernelRegistry::GetKernelExec(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
246                                   const InnerContext *ctx, const mindspore::Context *ms_ctx,
247                                   const kernel::KernelKey &key, OpParameter *parameter, kernel::KernelExec **kernel,
248                                   const void *primitive) {
249   CHECK_NULL_RETURN(kernel);
250 #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
251   if (key.provider != kBuiltin) {
252     CHECK_NULL_RETURN(ms_ctx);
253     auto ret = GetCustomKernel(in_tensors, out_tensors, ms_ctx, key, kernel, primitive);
254     if (ret == RET_OK) {
255       (*kernel)->set_context(ctx);
256     }
257     return ret;
258   }
259 #endif
260 
261   CHECK_NULL_RETURN(ctx);
262   auto modify_key = key;
263   auto lite_kernel = GetLiteKernel(in_tensors, out_tensors, ctx, &modify_key, parameter);
264   if (lite_kernel != nullptr) {
265     std::shared_ptr<kernel::Kernel> shared_kernel(lite_kernel);
266     auto *kernel_exec = new (std::nothrow) kernel::KernelExec(shared_kernel);
267     if (kernel_exec != nullptr) {
268       kernel_exec->set_desc(modify_key);
269       kernel_exec->set_context(ctx);
270       *kernel = kernel_exec;
271       return RET_OK;
272     }
273   }
274   MS_LOG(WARNING) << "common cpu kernel registry failed";
275   return RET_ERROR;
276 }
277 }  // namespace mindspore::lite
278