1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "src/registry/kernel_interface_registry.h"
#include <cstddef>
#include <memory>
#include <new>
#include "include/registry/register_kernel_interface.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/common/version_manager.h"
#include "schema/model_generated.h"
#include "include/api/kernel.h"
24
25 using mindspore::registry::KernelInterfaceCreator;
26 using mindspore::schema::PrimitiveType_MAX;
27 using mindspore::schema::PrimitiveType_MIN;
28 namespace mindspore {
29 namespace registry {
30 namespace {
31 static constexpr auto kMaxProviderNum = 10;
32 static constexpr auto KMaxCustomTypeNum = 200;
33 static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN + 1;
GetCustomType(const schema::Primitive * primitive)34 std::string GetCustomType(const schema::Primitive *primitive) {
35 auto param = primitive->value_as_Custom();
36 if (param == nullptr || param->type() == nullptr) {
37 return "";
38 }
39
40 return param->type()->str();
41 }
42 } // namespace
43
CustomReg(const std::string & provider,const std::string & type,const KernelInterfaceCreator creator)44 Status KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &type,
45 const KernelInterfaceCreator creator) {
46 if (creator == nullptr) {
47 MS_LOG(ERROR) << "custom kernel interface creator is nullptr.";
48 return kLiteNullptr;
49 }
50 auto provider_iter = custom_creators_.find(provider);
51 if (provider_iter == custom_creators_.end() && custom_creators_.size() >= kMaxProviderNum) {
52 MS_LOG(ERROR) << "register too many provider!";
53 return kLiteError;
54 }
55 if (provider_iter != custom_creators_.end()) {
56 auto type_iter = provider_iter->second.find(type);
57 if (type_iter == provider_iter->second.end() && provider_iter->second.size() >= KMaxCustomTypeNum) {
58 MS_LOG(ERROR) << "register too many custom type!";
59 return kLiteError;
60 }
61 }
62 custom_creators_[provider][type] = creator;
63 return kSuccess;
64 }
65
GetCacheInterface(const std::string & provider,int op_type)66 std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCacheInterface(const std::string &provider,
67 int op_type) {
68 if (provider.empty()) {
69 return nullptr;
70 }
71 auto provider_iter = kernel_interfaces_.find(provider);
72 if (provider_iter != kernel_interfaces_.end()) {
73 auto kernel_iter = provider_iter->second.find(op_type);
74 if (kernel_iter != provider_iter->second.end()) {
75 return kernel_iter->second;
76 }
77 }
78 return nullptr;
79 }
80
GetCustomCacheInterface(const std::string & provider,const std::string & type)81 std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomCacheInterface(const std::string &provider,
82 const std::string &type) {
83 if (provider.empty()) {
84 return nullptr;
85 }
86 auto provider_iter = custom_kernels_.find(provider);
87 if (provider_iter == custom_kernels_.end()) {
88 return nullptr;
89 }
90 auto kernel_iter = provider_iter->second.find(type);
91 if (kernel_iter != provider_iter->second.end()) {
92 return kernel_iter->second;
93 }
94 return nullptr;
95 }
96
GetCustomKernelInterface(const schema::Primitive * primitive,const kernel::Kernel * kernel)97 std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomKernelInterface(
98 const schema::Primitive *primitive, const kernel::Kernel *kernel) {
99 std::unique_lock<std::mutex> lock(mutex_);
100 std::string type;
101 if (kernel == nullptr) {
102 type = GetCustomType(primitive);
103 } else {
104 type = kernel->GetAttr("type");
105 }
106 for (auto &&item : custom_creators_) {
107 auto &&provider = item.first;
108 auto kernel_interface = GetCustomCacheInterface(provider, type);
109 if (kernel_interface != nullptr) {
110 return kernel_interface;
111 }
112 auto provider_iter = custom_creators_.find(provider);
113 if (provider_iter == custom_creators_.end()) {
114 return nullptr;
115 }
116 auto creator_iter = provider_iter->second.find(type);
117 if (creator_iter != provider_iter->second.end()) {
118 if (creator_iter->second == nullptr) {
119 MS_LOG(ERROR) << "custom kernel interface creator is nullptr.";
120 return nullptr;
121 }
122 kernel_interface = creator_iter->second();
123 custom_kernels_[provider][type] = kernel_interface;
124 return kernel_interface;
125 }
126 }
127
128 return nullptr;
129 }
130
GetKernelInterface(const std::string & provider,const schema::Primitive * primitive,const kernel::Kernel * kernel)131 std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInterface(const std::string &provider,
132 const schema::Primitive *primitive,
133 const kernel::Kernel *kernel) {
134 if (primitive == nullptr && kernel == nullptr) {
135 return nullptr;
136 }
137 int op_type;
138 if (kernel == nullptr) {
139 op_type = static_cast<int>(primitive->value_type());
140 } else {
141 op_type = static_cast<int>(kernel->type());
142 }
143 if (op_type > PrimitiveType_MAX || op_type <= PrimitiveType_MIN) {
144 return nullptr;
145 }
146 if (op_type == schema::PrimitiveType_Custom) {
147 return GetCustomKernelInterface(primitive, kernel);
148 }
149
150 std::unique_lock<std::mutex> lock(mutex_);
151 auto kernel_interface = GetCacheInterface(provider, op_type);
152 if (kernel_interface != nullptr) {
153 return kernel_interface;
154 }
155 auto iter = kernel_creators_.find(provider);
156 if (iter == kernel_creators_.end()) {
157 return nullptr;
158 }
159
160 auto creator = iter->second[op_type];
161 if (creator != nullptr) {
162 kernel_interface = creator();
163 kernel_interfaces_[provider][op_type] = kernel_interface;
164 return kernel_interface;
165 }
166 return nullptr;
167 }
168
Reg(const std::string & provider,int op_type,const KernelInterfaceCreator creator)169 Status KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
170 if (creator == nullptr) {
171 MS_LOG(ERROR) << "kernel interface creator is nullptr.";
172 return kLiteNullptr;
173 }
174 if (op_type <= PrimitiveType_MIN || op_type > PrimitiveType_MAX) {
175 MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << PrimitiveType_MAX;
176 return kLiteParamInvalid;
177 }
178
179 if (provider.empty()) {
180 MS_LOG(ERROR) << "Input provider is empty!";
181 return kLiteParamInvalid;
182 }
183 std::unique_lock<std::mutex> lock(mutex_);
184 auto iter = kernel_creators_.find(provider);
185 if (iter == kernel_creators_.end()) {
186 if (kernel_creators_.size() >= kMaxProviderNum) {
187 MS_LOG(ERROR) << "register too many provider!";
188 return kLiteError;
189 }
190 kernel_creators_[provider] =
191 reinterpret_cast<KernelInterfaceCreator *>(calloc(kMaxKernelNum, sizeof(KernelInterfaceCreator)));
192 if (kernel_creators_[provider] == nullptr) {
193 MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!";
194 return kLiteError;
195 }
196 }
197
198 kernel_creators_[provider][op_type] = creator;
199 return kSuccess;
200 }
201
~KernelInterfaceRegistry()202 KernelInterfaceRegistry::~KernelInterfaceRegistry() {
203 for (auto &&item : kernel_creators_) {
204 free(item.second);
205 item.second = nullptr;
206 }
207 }
208 } // namespace registry
209 } // namespace mindspore
210