• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/registry/kernel_interface_registry.h"
17 #include <memory>
18 #include "include/registry/register_kernel_interface.h"
19 #include "include/errorcode.h"
20 #include "src/common/log_adapter.h"
21 #include "src/common/version_manager.h"
22 #include "schema/model_generated.h"
23 #include "include/api/kernel.h"
24 
25 using mindspore::registry::KernelInterfaceCreator;
26 using mindspore::schema::PrimitiveType_MAX;
27 using mindspore::schema::PrimitiveType_MIN;
28 namespace mindspore {
29 namespace registry {
30 namespace {
31 static constexpr auto kMaxProviderNum = 10;
32 static constexpr auto KMaxCustomTypeNum = 200;
33 static const auto kMaxKernelNum = PrimitiveType_MAX - PrimitiveType_MIN + 1;
GetCustomType(const schema::Primitive * primitive)34 std::string GetCustomType(const schema::Primitive *primitive) {
35   auto param = primitive->value_as_Custom();
36   if (param == nullptr || param->type() == nullptr) {
37     return "";
38   }
39 
40   return param->type()->str();
41 }
42 }  // namespace
43 
CustomReg(const std::string & provider,const std::string & type,const KernelInterfaceCreator creator)44 Status KernelInterfaceRegistry::CustomReg(const std::string &provider, const std::string &type,
45                                           const KernelInterfaceCreator creator) {
46   auto provider_iter = custom_creators_.find(provider);
47   if (provider_iter == custom_creators_.end() && custom_creators_.size() >= kMaxProviderNum) {
48     MS_LOG(ERROR) << "register too many provider!";
49     return kLiteError;
50   }
51   if (provider_iter != custom_creators_.end()) {
52     auto type_iter = provider_iter->second.find(type);
53     if (type_iter == provider_iter->second.end() && provider_iter->second.size() >= KMaxCustomTypeNum) {
54       MS_LOG(ERROR) << "register too many custom type!";
55       return kLiteError;
56     }
57   }
58   custom_creators_[provider][type] = creator;
59   return kSuccess;
60 }
61 
GetCacheInterface(const std::string & provider,int op_type)62 std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCacheInterface(const std::string &provider,
63                                                                                     int op_type) {
64   if (provider.empty()) {
65     return nullptr;
66   }
67   auto provider_iter = kernel_interfaces_.find(provider);
68   if (provider_iter != kernel_interfaces_.end()) {
69     auto kernel_iter = provider_iter->second.find(op_type);
70     if (kernel_iter != provider_iter->second.end()) {
71       return kernel_iter->second;
72     }
73   }
74   return nullptr;
75 }
76 
GetCustomCacheInterface(const std::string & provider,const std::string & type)77 std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomCacheInterface(const std::string &provider,
78                                                                                           const std::string &type) {
79   if (provider.empty()) {
80     return nullptr;
81   }
82   auto provider_iter = custom_kernels_.find(provider);
83   if (provider_iter == custom_kernels_.end()) {
84     return nullptr;
85   }
86   auto kernel_iter = provider_iter->second.find(type);
87   if (kernel_iter != provider_iter->second.end()) {
88     return kernel_iter->second;
89   }
90   return nullptr;
91 }
92 
GetCustomKernelInterface(const schema::Primitive * primitive,const kernel::Kernel * kernel)93 std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetCustomKernelInterface(
94   const schema::Primitive *primitive, const kernel::Kernel *kernel) {
95   std::unique_lock<std::mutex> lock(mutex_);
96   std::string type;
97   if (kernel == nullptr) {
98     type = GetCustomType(primitive);
99   } else {
100     type = kernel->GetAttr("type");
101   }
102   for (auto &&item : custom_creators_) {
103     auto &&provider = item.first;
104     auto kernel_interface = GetCustomCacheInterface(provider, type);
105     if (kernel_interface != nullptr) {
106       return kernel_interface;
107     }
108     auto provider_iter = custom_creators_.find(provider);
109     if (provider_iter == custom_creators_.end()) {
110       return nullptr;
111     }
112     auto creator_iter = provider_iter->second.find(type);
113     if (creator_iter != provider_iter->second.end()) {
114       kernel_interface = creator_iter->second();
115       custom_kernels_[provider][type] = kernel_interface;
116       return kernel_interface;
117     }
118   }
119 
120   return nullptr;
121 }
122 
GetKernelInterface(const std::string & provider,const schema::Primitive * primitive,const kernel::Kernel * kernel)123 std::shared_ptr<kernel::KernelInterface> KernelInterfaceRegistry::GetKernelInterface(const std::string &provider,
124                                                                                      const schema::Primitive *primitive,
125                                                                                      const kernel::Kernel *kernel) {
126   if (primitive == nullptr && kernel == nullptr) {
127     return nullptr;
128   }
129   int op_type;
130   if (kernel == nullptr) {
131     op_type = static_cast<int>(primitive->value_type());
132   } else {
133     op_type = static_cast<int>(kernel->type());
134   }
135   if (op_type > PrimitiveType_MAX || op_type <= PrimitiveType_MIN) {
136     return nullptr;
137   }
138   if (op_type == schema::PrimitiveType_Custom) {
139     return GetCustomKernelInterface(primitive, kernel);
140   }
141 
142   std::unique_lock<std::mutex> lock(mutex_);
143   auto kernel_interface = GetCacheInterface(provider, op_type);
144   if (kernel_interface != nullptr) {
145     return kernel_interface;
146   }
147   auto iter = kernel_creators_.find(provider);
148   if (iter == kernel_creators_.end()) {
149     return nullptr;
150   }
151 
152   auto creator = iter->second[op_type];
153   if (creator != nullptr) {
154     kernel_interface = creator();
155     kernel_interfaces_[provider][op_type] = kernel_interface;
156     return kernel_interface;
157   }
158   return nullptr;
159 }
160 
Reg(const std::string & provider,int op_type,const KernelInterfaceCreator creator)161 Status KernelInterfaceRegistry::Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
162   if (op_type <= PrimitiveType_MIN || op_type > PrimitiveType_MAX) {
163     MS_LOG(ERROR) << "reg op_type invalid!op_type: " << op_type << ", max value: " << PrimitiveType_MAX;
164     return kLiteParamInvalid;
165   }
166 
167   if (provider.empty()) {
168     MS_LOG(ERROR) << "Input provider is empty!";
169     return kLiteParamInvalid;
170   }
171   std::unique_lock<std::mutex> lock(mutex_);
172   auto iter = kernel_creators_.find(provider);
173   if (iter == kernel_creators_.end()) {
174     if (kernel_creators_.size() >= kMaxProviderNum) {
175       MS_LOG(ERROR) << "register too many provider!";
176       return kLiteError;
177     }
178     kernel_creators_[provider] =
179       reinterpret_cast<KernelInterfaceCreator *>(calloc(kMaxKernelNum, sizeof(KernelInterfaceCreator)));
180     if (kernel_creators_[provider] == nullptr) {
181       MS_LOG(ERROR) << "malloc kernel dev delegate creator fail!";
182       return kLiteError;
183     }
184   }
185 
186   kernel_creators_[provider][op_type] = creator;
187   return kSuccess;
188 }
189 
~KernelInterfaceRegistry()190 KernelInterfaceRegistry::~KernelInterfaceRegistry() {
191   for (auto &&item : kernel_creators_) {
192     free(item.second);
193     item.second = nullptr;
194   }
195 }
196 }  // namespace registry
197 }  // namespace mindspore
198