• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/registry/register_kernel_impl.h"
17 #include "include/registry/register_kernel.h"
18 #include "include/errorcode.h"
19 #include "src/common/version_manager.h"
20 #include "src/common/log_adapter.h"
21 
22 using mindspore::registry::CreateKernel;
23 using mindspore::registry::KernelDesc;
24 using mindspore::schema::PrimitiveType_MAX;
25 using mindspore::schema::PrimitiveType_MIN;
26 namespace mindspore::registry {
27 namespace {
28 static const auto kOpTypeLen = PrimitiveType_MAX - PrimitiveType_MIN + 1;
29 static const auto kDataTypeLen =
30   static_cast<int>(DataType::kNumberTypeEnd) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
31 static const auto kKernelMaxNum = kOpTypeLen * kDataTypeLen;
32 static constexpr auto kMaxProviderNum = 10;
33 static constexpr auto kMaxArchPerProviderNum = 10;
34 static constexpr auto kMaxCustomTypeNum = 200;
GetFuncIndex(const KernelDesc & desc)35 int GetFuncIndex(const KernelDesc &desc) {
36   if (desc.data_type >= DataType::kNumberTypeEnd) {
37     return -1;
38   }
39   int data_type_index = static_cast<int>(desc.data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
40   if (data_type_index < 0) {
41     return -1;
42   }
43   int index = data_type_index * kOpTypeLen + desc.type;
44   if (index >= kKernelMaxNum) {
45     return -1;
46   }
47   return index;
48 }
49 }  // namespace
50 
RegCustomKernel(const std::string & arch,const std::string & provider,DataType data_type,const std::string & type,const CreateKernel creator)51 Status RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
52                                            const std::string &type, const CreateKernel creator) {
53   int data_type_index = static_cast<int>(data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
54   if (data_type_index < 0 || data_type_index >= kDataTypeLen) {
55     MS_LOG(ERROR) << "invalid data_type: " << static_cast<int>(data_type) << "!provider: " << provider;
56     return kLiteError;
57   }
58   std::unique_lock<std::mutex> lock(lock_);
59   auto provider_iter = custom_kernel_creators_.find(provider);
60   if (provider_iter == custom_kernel_creators_.end() && custom_kernel_creators_.size() >= kMaxProviderNum) {
61     MS_LOG(ERROR) << "register too many provider!";
62     return kLiteError;
63   }
64   if (provider_iter != custom_kernel_creators_.end()) {
65     auto arch_iter = provider_iter->second.find(arch);
66     if (arch_iter == provider_iter->second.end()) {
67       if (provider_iter->second.size() >= kMaxArchPerProviderNum) {
68         MS_LOG(ERROR) << "register too many arch!";
69         return kLiteError;
70       }
71     } else {
72       auto type_iter = arch_iter->second.find(type);
73       if (type_iter == arch_iter->second.end() && arch_iter->second.size() >= kMaxCustomTypeNum) {
74         MS_LOG(ERROR) << "register too many type!";
75         return kLiteError;
76       }
77     }
78   }
79   if (custom_kernel_creators_[provider][arch][type] == nullptr) {
80     custom_kernel_creators_[provider][arch][type] =
81       reinterpret_cast<CreateKernel *>(calloc(kDataTypeLen, sizeof(CreateKernel)));
82     if (custom_kernel_creators_[provider][arch][type] == nullptr) {
83       MS_LOG(ERROR) << "malloc custom kernel creator fail!provider: " << provider << ", arch: " << arch;
84       return kLiteError;
85     }
86   }
87 
88   custom_kernel_creators_[provider][arch][type][data_type_index] = creator;
89   return kSuccess;
90 }
91 
RegKernel(const std::string & arch,const std::string & provider,DataType data_type,int type,const registry::CreateKernel creator)92 Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
93                                      const registry::CreateKernel creator) {
94   if (type <= static_cast<int>(PrimitiveType_MIN) || type > static_cast<int>(PrimitiveType_MAX)) {
95     MS_LOG(ERROR) << "Invalid op type : " << type;
96     return kLiteParamInvalid;
97   }
98   KernelDesc desc = {data_type, type, arch, provider};
99   int index = GetFuncIndex(desc);
100   if (index < 0) {
101     MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << static_cast<int>(data_type) << ",op type "
102                   << type;
103     return kLiteError;
104   }
105   std::unique_lock<std::mutex> lock(lock_);
106   auto iter = kernel_creators_.find(provider);
107   if (iter == kernel_creators_.end()) {
108     if (kernel_creators_.size() >= kMaxProviderNum) {
109       MS_LOG(ERROR) << "register too many provider!";
110       return kLiteError;
111     }
112     kernel_creators_[provider][arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel)));
113     if (kernel_creators_[provider][arch] == nullptr) {
114       MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
115       return kLiteError;
116     }
117   } else {
118     auto iter_arch = iter->second.find(arch);
119     if (iter_arch == iter->second.end()) {
120       if (iter->second.size() >= kMaxArchPerProviderNum) {
121         MS_LOG(ERROR) << "register too many arch!";
122         return kLiteError;
123       }
124       iter->second[arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel)));
125       if (iter->second[arch] == nullptr) {
126         MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
127         return kLiteError;
128       }
129     }
130   }
131 
132   kernel_creators_[provider][arch][index] = creator;
133   return kSuccess;
134 }
135 
GetCustomKernelCreator(const schema::Primitive * primitive,KernelDesc * desc)136 registry::CreateKernel RegistryKernelImpl::GetCustomKernelCreator(const schema::Primitive *primitive,
137                                                                   KernelDesc *desc) {
138   int data_type_index = static_cast<int>(desc->data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
139   if (data_type_index < 0 || desc->data_type >= DataType::kNumberTypeEnd) {
140     return nullptr;
141   }
142   auto param = primitive->value_as_Custom();
143   if (param == nullptr || param->type() == nullptr) {
144     return nullptr;
145   }
146   auto custom_type = param->type()->str();
147   if (!desc->provider.empty() && !desc->arch.empty()) {
148     auto creator_buf = custom_kernel_creators_[desc->provider][desc->arch][custom_type];
149     if (creator_buf != nullptr && creator_buf[data_type_index] != nullptr) {
150       return creator_buf[data_type_index];
151     }
152     return nullptr;
153   }
154   for (auto &&providers : custom_kernel_creators_) {
155     auto archs = providers.second;
156     auto archs_iter = std::find_if(archs.begin(), archs.end(), [custom_type, data_type_index](auto &&item) {
157       return item.second[custom_type] != nullptr && item.second[custom_type][data_type_index] != nullptr;
158     });
159     if (archs_iter != archs.end()) {
160       desc->arch = archs_iter->first;
161       return archs_iter->second[custom_type][data_type_index];
162     }
163   }
164 
165   return nullptr;
166 }
167 
GetProviderCreator(const schema::Primitive * primitive,KernelDesc * desc)168 registry::CreateKernel RegistryKernelImpl::GetProviderCreator(const schema::Primitive *primitive, KernelDesc *desc) {
169   registry::CreateKernel creator = nullptr;
170   std::unique_lock<std::mutex> lock(lock_);
171   if (desc->type == schema::PrimitiveType_Custom) {
172     return GetCustomKernelCreator(primitive, desc);
173   }
174 
175   auto index = GetFuncIndex(*desc);
176   if (index < 0) {
177     return nullptr;
178   }
179   for (auto &&item : kernel_creators_) {
180     if (item.first != desc->provider) {
181       continue;
182     }
183     for (auto &&arch_item : item.second) {
184       if (arch_item.first != desc->arch) {
185         continue;
186       }
187       creator = arch_item.second[index];
188       if (creator != nullptr) {
189         break;
190       }
191     }
192     if (creator != nullptr) {
193       break;
194     }
195   }
196   return creator;
197 }
198 
~RegistryKernelImpl()199 RegistryKernelImpl::~RegistryKernelImpl() {
200   for (auto &&item : kernel_creators_) {
201     for (auto &&creator : item.second) {
202       free(creator.second);
203       creator.second = nullptr;
204     }
205   }
206   for (auto &&provider : custom_kernel_creators_) {
207     for (auto &&arch : provider.second) {
208       for (auto &&creator : arch.second) {
209         free(creator.second);
210         creator.second = nullptr;
211       }
212     }
213   }
214 }
215 }  // namespace mindspore::registry
216