1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "src/registry/register_kernel_impl.h"
17 #include "include/registry/register_kernel.h"
18 #include "include/errorcode.h"
19 #include "src/common/version_manager.h"
20 #include "src/common/log_adapter.h"
21
22 using mindspore::registry::CreateKernel;
23 using mindspore::registry::KernelDesc;
24 using mindspore::schema::PrimitiveType_MAX;
25 using mindspore::schema::PrimitiveType_MIN;
26 namespace mindspore::registry {
27 namespace {
28 static const auto kOpTypeLen = PrimitiveType_MAX - PrimitiveType_MIN + 1;
29 static const auto kDataTypeLen =
30 static_cast<int>(DataType::kNumberTypeEnd) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
31 static const auto kKernelMaxNum = kOpTypeLen * kDataTypeLen;
32 static constexpr auto kMaxProviderNum = 10;
33 static constexpr auto kMaxArchPerProviderNum = 10;
34 static constexpr auto kMaxCustomTypeNum = 200;
GetFuncIndex(const KernelDesc & desc)35 int GetFuncIndex(const KernelDesc &desc) {
36 if (desc.data_type >= DataType::kNumberTypeEnd) {
37 return -1;
38 }
39 int data_type_index = static_cast<int>(desc.data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
40 if (data_type_index < 0) {
41 return -1;
42 }
43 int index = data_type_index * kOpTypeLen + desc.type;
44 if (index >= kKernelMaxNum) {
45 return -1;
46 }
47 return index;
48 }
49 } // namespace
50
RegCustomKernel(const std::string & arch,const std::string & provider,DataType data_type,const std::string & type,const CreateKernel creator)51 Status RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
52 const std::string &type, const CreateKernel creator) {
53 int data_type_index = static_cast<int>(data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
54 if (data_type_index < 0 || data_type_index >= kDataTypeLen) {
55 MS_LOG(ERROR) << "invalid data_type: " << static_cast<int>(data_type) << "!provider: " << provider;
56 return kLiteError;
57 }
58 std::unique_lock<std::mutex> lock(lock_);
59 auto provider_iter = custom_kernel_creators_.find(provider);
60 if (provider_iter == custom_kernel_creators_.end() && custom_kernel_creators_.size() >= kMaxProviderNum) {
61 MS_LOG(ERROR) << "register too many provider!";
62 return kLiteError;
63 }
64 if (provider_iter != custom_kernel_creators_.end()) {
65 auto arch_iter = provider_iter->second.find(arch);
66 if (arch_iter == provider_iter->second.end()) {
67 if (provider_iter->second.size() >= kMaxArchPerProviderNum) {
68 MS_LOG(ERROR) << "register too many arch!";
69 return kLiteError;
70 }
71 } else {
72 auto type_iter = arch_iter->second.find(type);
73 if (type_iter == arch_iter->second.end() && arch_iter->second.size() >= kMaxCustomTypeNum) {
74 MS_LOG(ERROR) << "register too many type!";
75 return kLiteError;
76 }
77 }
78 }
79 if (custom_kernel_creators_[provider][arch][type] == nullptr) {
80 custom_kernel_creators_[provider][arch][type] =
81 reinterpret_cast<CreateKernel *>(calloc(kDataTypeLen, sizeof(CreateKernel)));
82 if (custom_kernel_creators_[provider][arch][type] == nullptr) {
83 MS_LOG(ERROR) << "malloc custom kernel creator fail!provider: " << provider << ", arch: " << arch;
84 return kLiteError;
85 }
86 }
87
88 custom_kernel_creators_[provider][arch][type][data_type_index] = creator;
89 return kSuccess;
90 }
91
RegKernel(const std::string & arch,const std::string & provider,DataType data_type,int type,const registry::CreateKernel creator)92 Status RegistryKernelImpl::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
93 const registry::CreateKernel creator) {
94 if (type <= static_cast<int>(PrimitiveType_MIN) || type > static_cast<int>(PrimitiveType_MAX)) {
95 MS_LOG(ERROR) << "Invalid op type : " << type;
96 return kLiteParamInvalid;
97 }
98 KernelDesc desc = {data_type, type, arch, provider};
99 int index = GetFuncIndex(desc);
100 if (index < 0) {
101 MS_LOG(ERROR) << "invalid kernel key, arch " << arch << ", data_type" << static_cast<int>(data_type) << ",op type "
102 << type;
103 return kLiteError;
104 }
105 std::unique_lock<std::mutex> lock(lock_);
106 auto iter = kernel_creators_.find(provider);
107 if (iter == kernel_creators_.end()) {
108 if (kernel_creators_.size() >= kMaxProviderNum) {
109 MS_LOG(ERROR) << "register too many provider!";
110 return kLiteError;
111 }
112 kernel_creators_[provider][arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel)));
113 if (kernel_creators_[provider][arch] == nullptr) {
114 MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
115 return kLiteError;
116 }
117 } else {
118 auto iter_arch = iter->second.find(arch);
119 if (iter_arch == iter->second.end()) {
120 if (iter->second.size() >= kMaxArchPerProviderNum) {
121 MS_LOG(ERROR) << "register too many arch!";
122 return kLiteError;
123 }
124 iter->second[arch] = reinterpret_cast<CreateKernel *>(calloc(kKernelMaxNum, sizeof(CreateKernel)));
125 if (iter->second[arch] == nullptr) {
126 MS_LOG(ERROR) << "malloc kernel creator buffer fail! provider: " << provider << ",arch:" << arch;
127 return kLiteError;
128 }
129 }
130 }
131
132 kernel_creators_[provider][arch][index] = creator;
133 return kSuccess;
134 }
135
GetCustomKernelCreator(const schema::Primitive * primitive,KernelDesc * desc)136 registry::CreateKernel RegistryKernelImpl::GetCustomKernelCreator(const schema::Primitive *primitive,
137 KernelDesc *desc) {
138 int data_type_index = static_cast<int>(desc->data_type) - static_cast<int>(DataType::kNumberTypeBegin) - 1;
139 if (data_type_index < 0 || desc->data_type >= DataType::kNumberTypeEnd) {
140 return nullptr;
141 }
142 auto param = primitive->value_as_Custom();
143 if (param == nullptr || param->type() == nullptr) {
144 return nullptr;
145 }
146 auto custom_type = param->type()->str();
147 if (!desc->provider.empty() && !desc->arch.empty()) {
148 auto creator_buf = custom_kernel_creators_[desc->provider][desc->arch][custom_type];
149 if (creator_buf != nullptr && creator_buf[data_type_index] != nullptr) {
150 return creator_buf[data_type_index];
151 }
152 return nullptr;
153 }
154 for (auto &&providers : custom_kernel_creators_) {
155 auto archs = providers.second;
156 auto archs_iter = std::find_if(archs.begin(), archs.end(), [custom_type, data_type_index](auto &&item) {
157 return item.second[custom_type] != nullptr && item.second[custom_type][data_type_index] != nullptr;
158 });
159 if (archs_iter != archs.end()) {
160 desc->arch = archs_iter->first;
161 return archs_iter->second[custom_type][data_type_index];
162 }
163 }
164
165 return nullptr;
166 }
167
GetProviderCreator(const schema::Primitive * primitive,KernelDesc * desc)168 registry::CreateKernel RegistryKernelImpl::GetProviderCreator(const schema::Primitive *primitive, KernelDesc *desc) {
169 registry::CreateKernel creator = nullptr;
170 std::unique_lock<std::mutex> lock(lock_);
171 if (desc->type == schema::PrimitiveType_Custom) {
172 return GetCustomKernelCreator(primitive, desc);
173 }
174
175 auto index = GetFuncIndex(*desc);
176 if (index < 0) {
177 return nullptr;
178 }
179 for (auto &&item : kernel_creators_) {
180 if (item.first != desc->provider) {
181 continue;
182 }
183 for (auto &&arch_item : item.second) {
184 if (arch_item.first != desc->arch) {
185 continue;
186 }
187 creator = arch_item.second[index];
188 if (creator != nullptr) {
189 break;
190 }
191 }
192 if (creator != nullptr) {
193 break;
194 }
195 }
196 return creator;
197 }
198
~RegistryKernelImpl()199 RegistryKernelImpl::~RegistryKernelImpl() {
200 for (auto &&item : kernel_creators_) {
201 for (auto &&creator : item.second) {
202 free(creator.second);
203 creator.second = nullptr;
204 }
205 }
206 for (auto &&provider : custom_kernel_creators_) {
207 for (auto &&arch : provider.second) {
208 for (auto &&creator : arch.second) {
209 free(creator.second);
210 creator.second = nullptr;
211 }
212 }
213 }
214 }
215 } // namespace mindspore::registry
216