• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "src/registry/kernel_interface_registry.h"
#include <memory>
#include <new>
#include "include/registry/register_kernel_interface.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/common/version_manager.h"
#include "schema/model_generated.h"
#include "include/api/kernel.h"
24 
25 using mindspore::registry::KernelInterfaceCreator;
26 using mindspore::schema::PrimitiveType_MAX;
27 using mindspore::schema::PrimitiveType_MIN;
28 namespace mindspore {
29 namespace registry {
30 namespace {
31 static constexpr auto kMaxProviderNum = 10;
32 static constexpr auto KMaxCustomTypeNum = 200;
33 static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN + 1;
GetCustomType(const schema::Primitive * primitive)34 std::string GetCustomType(const schema::Primitive *primitive) {
35   auto param = primitive->value_as_Custom();
36   if (param == nullptr || param->type() == nullptr) {
37     return "";
38   }
39 
40   return param->type()->str();
41 }
42 }  // namespace
43 
CustomReg(const std::string & provider,const std::string & type,const KernelInterfaceCreator creator)44 Status KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &type,
45                                           const KernelInterfaceCreator creator) {
46   if (creator == nullptr) {
47     MS_LOG(ERROR) << "custom kernel interface creator is nullptr.";
48     return kLiteNullptr;
49   }
50   auto provider_iter = custom_creators_.find(provider);
51   if (provider_iter == custom_creators_.end() && custom_creators_.size() >= kMaxProviderNum) {
52     MS_LOG(ERROR) << "register too many provider!";
53     return kLiteError;
54   }
55   if (provider_iter != custom_creators_.end()) {
56     auto type_iter = provider_iter->second.find(type);
57     if (type_iter == provider_iter->second.end() && provider_iter->second.size() >= KMaxCustomTypeNum) {
58       MS_LOG(ERROR) << "register too many custom type!";
59       return kLiteError;
60     }
61   }
62   custom_creators_[provider][type] = creator;
63   return kSuccess;
64 }
65 
GetCacheInterface(const std::string & provider,int op_type)66 std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCacheInterface(const std::string &provider,
67                                                                                     int op_type) {
68   if (provider.empty()) {
69     return nullptr;
70   }
71   auto provider_iter = kernel_interfaces_.find(provider);
72   if (provider_iter != kernel_interfaces_.end()) {
73     auto kernel_iter = provider_iter->second.find(op_type);
74     if (kernel_iter != provider_iter->second.end()) {
75       return kernel_iter->second;
76     }
77   }
78   return nullptr;
79 }
80 
GetCustomCacheInterface(const std::string & provider,const std::string & type)81 std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomCacheInterface(const std::string &provider,
82                                                                                           const std::string &type) {
83   if (provider.empty()) {
84     return nullptr;
85   }
86   auto provider_iter = custom_kernels_.find(provider);
87   if (provider_iter == custom_kernels_.end()) {
88     return nullptr;
89   }
90   auto kernel_iter = provider_iter->second.find(type);
91   if (kernel_iter != provider_iter->second.end()) {
92     return kernel_iter->second;
93   }
94   return nullptr;
95 }
96 
GetCustomKernelInterface(const schema::Primitive * primitive,const kernel::Kernel * kernel)97 std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomKernelInterface(
98   const schema::Primitive *primitive, const kernel::Kernel *kernel) {
99   std::unique_lock<std::mutex> lock(mutex_);
100   std::string type;
101   if (kernel == nullptr) {
102     type = GetCustomType(primitive);
103   } else {
104     type = kernel->GetAttr("type");
105   }
106   for (auto &&item : custom_creators_) {
107     auto &&provider = item.first;
108     auto kernel_interface = GetCustomCacheInterface(provider, type);
109     if (kernel_interface != nullptr) {
110       return kernel_interface;
111     }
112     auto provider_iter = custom_creators_.find(provider);
113     if (provider_iter == custom_creators_.end()) {
114       return nullptr;
115     }
116     auto creator_iter = provider_iter->second.find(type);
117     if (creator_iter != provider_iter->second.end()) {
118       if (creator_iter->second == nullptr) {
119         MS_LOG(ERROR) << "custom kernel interface creator is nullptr.";
120         return nullptr;
121       }
122       kernel_interface = creator_iter->second();
123       custom_kernels_[provider][type] = kernel_interface;
124       return kernel_interface;
125     }
126   }
127 
128   return nullptr;
129 }
130 
GetKernelInterface(const std::string & provider,const schema::Primitive * primitive,const kernel::Kernel * kernel)131 std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInterface(const std::string &provider,
132                                                                                      const schema::Primitive *primitive,
133                                                                                      const kernel::Kernel *kernel) {
134   if (primitive == nullptr && kernel == nullptr) {
135     return nullptr;
136   }
137   int op_type;
138   if (kernel == nullptr) {
139     op_type = static_cast<int>(primitive->value_type());
140   } else {
141     op_type = static_cast<int>(kernel->type());
142   }
143   if (op_type > PrimitiveType_MAX || op_type <= PrimitiveType_MIN) {
144     return nullptr;
145   }
146   if (op_type == schema::PrimitiveType_Custom) {
147     return GetCustomKernelInterface(primitive, kernel);
148   }
149 
150   std::unique_lock<std::mutex> lock(mutex_);
151   auto kernel_interface = GetCacheInterface(provider, op_type);
152   if (kernel_interface != nullptr) {
153     return kernel_interface;
154   }
155   auto iter = kernel_creators_.find(provider);
156   if (iter == kernel_creators_.end()) {
157     return nullptr;
158   }
159 
160   auto creator = iter->second[op_type];
161   if (creator != nullptr) {
162     kernel_interface = creator();
163     kernel_interfaces_[provider][op_type] = kernel_interface;
164     return kernel_interface;
165   }
166   return nullptr;
167 }
168 
Reg(const std::string & provider,int op_type,const KernelInterfaceCreator creator)169 Status KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
170   if (creator == nullptr) {
171     MS_LOG(ERROR) << "kernel interface creator is nullptr.";
172     return kLiteNullptr;
173   }
174   if (op_type <= PrimitiveType_MIN || op_type > PrimitiveType_MAX) {
175     MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << PrimitiveType_MAX;
176     return kLiteParamInvalid;
177   }
178 
179   if (provider.empty()) {
180     MS_LOG(ERROR) << "Input provider is empty!";
181     return kLiteParamInvalid;
182   }
183   std::unique_lock<std::mutex> lock(mutex_);
184   auto iter = kernel_creators_.find(provider);
185   if (iter == kernel_creators_.end()) {
186     if (kernel_creators_.size() >= kMaxProviderNum) {
187       MS_LOG(ERROR) << "register too many provider!";
188       return kLiteError;
189     }
190     kernel_creators_[provider] =
191       reinterpret_cast<KernelInterfaceCreator *>(calloc(kMaxKernelNum, sizeof(KernelInterfaceCreator)));
192     if (kernel_creators_[provider] == nullptr) {
193       MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!";
194       return kLiteError;
195     }
196   }
197 
198   kernel_creators_[provider][op_type] = creator;
199   return kSuccess;
200 }
201 
~KernelInterfaceRegistry()202 KernelInterfaceRegistry::~KernelInterfaceRegistry() {
203   for (auto &&item : kernel_creators_) {
204     free(item.second);
205     item.second = nullptr;
206   }
207 }
208 }  // namespace registry
209 }  // namespace mindspore
210