1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "src/registry/kernel_interface_registry.h"
#include <cstddef>
#include <memory>
#include <new>
#include "include/registry/register_kernel_interface.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/common/version_manager.h"
#include "schema/model_generated.h"
#include "include/api/kernel.h"
24
25 using mindspore::registry::KernelInterfaceCreator;
26 using mindspore::schema::PrimitiveType_MAX;
27 using mindspore::schema::PrimitiveType_MIN;
28 namespace mindspore {
29 namespace registry {
30 namespace {
31 static constexpr auto kMaxProviderNum = 10;
32 static constexpr auto KMaxCustomTypeNum = 200;
33 static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN + 1;
GetCustomType(const schema::Primitive * primitive)34 std::string GetCustomType(const schema::Primitive *primitive) {
35 auto param = primitive->value_as_Custom();
36 if (param == nullptr || param->type() == nullptr) {
37 return "";
38 }
39
40 return param->type()->str();
41 }
42 } // namespace
43
CustomReg(const std::string & provider,const std::string & type,const KernelInterfaceCreator creator)44 Status KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &type,
45 const KernelInterfaceCreator creator) {
46 auto provider_iter = custom_creators_.find(provider);
47 if (provider_iter == custom_creators_.end() && custom_creators_.size() >= kMaxProviderNum) {
48 MS_LOG(ERROR) << "register too many provider!";
49 return kLiteError;
50 }
51 if (provider_iter != custom_creators_.end()) {
52 auto type_iter = provider_iter->second.find(type);
53 if (type_iter == provider_iter->second.end() && provider_iter->second.size() >= KMaxCustomTypeNum) {
54 MS_LOG(ERROR) << "register too many custom type!";
55 return kLiteError;
56 }
57 }
58 custom_creators_[provider][type] = creator;
59 return kSuccess;
60 }
61
GetCacheInterface(const std::string & provider,int op_type)62 std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCacheInterface(const std::string &provider,
63 int op_type) {
64 if (provider.empty()) {
65 return nullptr;
66 }
67 auto provider_iter = kernel_interfaces_.find(provider);
68 if (provider_iter != kernel_interfaces_.end()) {
69 auto kernel_iter = provider_iter->second.find(op_type);
70 if (kernel_iter != provider_iter->second.end()) {
71 return kernel_iter->second;
72 }
73 }
74 return nullptr;
75 }
76
GetCustomCacheInterface(const std::string & provider,const std::string & type)77 std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomCacheInterface(const std::string &provider,
78 const std::string &type) {
79 if (provider.empty()) {
80 return nullptr;
81 }
82 auto provider_iter = custom_kernels_.find(provider);
83 if (provider_iter == custom_kernels_.end()) {
84 return nullptr;
85 }
86 auto kernel_iter = provider_iter->second.find(type);
87 if (kernel_iter != provider_iter->second.end()) {
88 return kernel_iter->second;
89 }
90 return nullptr;
91 }
92
GetCustomKernelInterface(const schema::Primitive * primitive,const kernel::Kernel * kernel)93 std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomKernelInterface(
94 const schema::Primitive *primitive, const kernel::Kernel *kernel) {
95 std::unique_lock<std::mutex> lock(mutex_);
96 std::string type;
97 if (kernel == nullptr) {
98 type = GetCustomType(primitive);
99 } else {
100 type = kernel->GetAttr("type");
101 }
102 for (auto &&item : custom_creators_) {
103 auto &&provider = item.first;
104 auto kernel_interface = GetCustomCacheInterface(provider, type);
105 if (kernel_interface != nullptr) {
106 return kernel_interface;
107 }
108 auto provider_iter = custom_creators_.find(provider);
109 if (provider_iter == custom_creators_.end()) {
110 return nullptr;
111 }
112 auto creator_iter = provider_iter->second.find(type);
113 if (creator_iter != provider_iter->second.end()) {
114 kernel_interface = creator_iter->second();
115 custom_kernels_[provider][type] = kernel_interface;
116 return kernel_interface;
117 }
118 }
119
120 return nullptr;
121 }
122
GetKernelInterface(const std::string & provider,const schema::Primitive * primitive,const kernel::Kernel * kernel)123 std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInterface(const std::string &provider,
124 const schema::Primitive *primitive,
125 const kernel::Kernel *kernel) {
126 if (primitive == nullptr && kernel == nullptr) {
127 return nullptr;
128 }
129 int op_type;
130 if (kernel == nullptr) {
131 op_type = static_cast<int>(primitive->value_type());
132 } else {
133 op_type = static_cast<int>(kernel->type());
134 }
135 if (op_type > PrimitiveType_MAX || op_type <= PrimitiveType_MIN) {
136 return nullptr;
137 }
138 if (op_type == schema::PrimitiveType_Custom) {
139 return GetCustomKernelInterface(primitive, kernel);
140 }
141
142 std::unique_lock<std::mutex> lock(mutex_);
143 auto kernel_interface = GetCacheInterface(provider, op_type);
144 if (kernel_interface != nullptr) {
145 return kernel_interface;
146 }
147 auto iter = kernel_creators_.find(provider);
148 if (iter == kernel_creators_.end()) {
149 return nullptr;
150 }
151
152 auto creator = iter->second[op_type];
153 if (creator != nullptr) {
154 kernel_interface = creator();
155 kernel_interfaces_[provider][op_type] = kernel_interface;
156 return kernel_interface;
157 }
158 return nullptr;
159 }
160
Reg(const std::string & provider,int op_type,const KernelInterfaceCreator creator)161 Status KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
162 if (op_type <= PrimitiveType_MIN || op_type > PrimitiveType_MAX) {
163 MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << PrimitiveType_MAX;
164 return kLiteParamInvalid;
165 }
166
167 if (provider.empty()) {
168 MS_LOG(ERROR) << "Input provider is empty!";
169 return kLiteParamInvalid;
170 }
171 std::unique_lock<std::mutex> lock(mutex_);
172 auto iter = kernel_creators_.find(provider);
173 if (iter == kernel_creators_.end()) {
174 if (kernel_creators_.size() >= kMaxProviderNum) {
175 MS_LOG(ERROR) << "register too many provider!";
176 return kLiteError;
177 }
178 kernel_creators_[provider] =
179 reinterpret_cast<KernelInterfaceCreator *>(calloc(kMaxKernelNum, sizeof(KernelInterfaceCreator)));
180 if (kernel_creators_[provider] == nullptr) {
181 MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!";
182 return kLiteError;
183 }
184 }
185
186 kernel_creators_[provider][op_type] = creator;
187 return kSuccess;
188 }
189
~KernelInterfaceRegistry()190 KernelInterfaceRegistry::~KernelInterfaceRegistry() {
191 for (auto &&item : kernel_creators_) {
192 free(item.second);
193 item.second = nullptr;
194 }
195 }
196 } // namespace registry
197 } // namespace mindspore
198