• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "bolt/bolt_kernel_manager.h"
18 #include "bolt/bolt_parameter_manager.h"
19 
20 namespace mindspore::kernel::bolt {
BoltSupportKernel(int op_type,TypeId data_type)21 bool BoltSupportKernel(int op_type, TypeId data_type) {
22   auto creator = BoltKernelRegistry::GetInstance()->Creator({op_type, data_type});
23   if (creator != nullptr) {
24     return true;
25   }
26   return false;
27 }
28 
BoltKernelRegistry(const ParameterSpec & param_spec,const std::vector<lite::Tensor * > & inputs,const std::vector<lite::Tensor * > & outputs,const lite::InnerContext * ctx,kernel::KernelKey * key)29 LiteKernel *BoltKernelRegistry(const ParameterSpec &param_spec, const std::vector<lite::Tensor *> &inputs,
30                                const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
31                                kernel::KernelKey *key) {
32   auto creator = BoltKernelRegistry::GetInstance()->Creator({key->type, key->data_type});
33   LiteKernel *kernel = nullptr;
34   if (creator != nullptr) {
35     kernel = creator(param_spec, inputs, outputs, ctx);
36   }
37   if (kernel == nullptr) {
38     MS_LOG(DEBUG) << "Create bolt kernel failed!";
39     return nullptr;
40   }
41   key->format = BoltKernelRegistry::GetInstance()->GetKernelFormat({key->type, key->data_type});
42   return kernel;
43 }
44 
BoltKernelRegistry(OpParameter * parameter,const std::vector<lite::Tensor * > & inputs,const std::vector<lite::Tensor * > & outputs,const lite::InnerContext * ctx,kernel::KernelKey * key)45 LiteKernel *BoltKernelRegistry(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
46                                const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
47                                kernel::KernelKey *key) {
48   // convert OpParameter to ParameterSpec
49   auto param_spec = BoltParameterRegistry::GetInstance()->CreateBoltParameter(parameter);
50   if (param_spec == nullptr) {
51     MS_LOG(DEBUG) << "Create bolt ParameterSpec failed!";
52     return nullptr;
53   }
54   return BoltKernelRegistry(*param_spec, inputs, outputs, ctx, key);
55 }
56 }  // namespace mindspore::kernel::bolt
57