1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BOLT_BOLT_KERNEL_MANAGER_H_
18 #define MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BOLT_BOLT_KERNEL_MANAGER_H_
19
#include <map>
#include <new>
#include <utility>
#include <vector>
#include "src/executor/kernel_exec.h"
#include "bolt/common/uni/include/parameter_spec.h"
25
26 namespace mindspore::kernel::bolt {
27 struct BoltKeyDesc {
28 int op_;
29 TypeId dt_;
30 bool operator<(const BoltKeyDesc &comp) const { return (op_ != comp.op_) ? (op_ < comp.op_) : (dt_ < comp.dt_); }
31 };
32
33 typedef LiteKernel *(*BoltCreator)(const ParameterSpec ¶m_spec, const std::vector<lite::Tensor *> &in,
34 const std::vector<lite::Tensor *> &out, const lite::InnerContext *ctx);
35
36 class BoltKernelRegistry {
37 public:
GetInstance()38 static BoltKernelRegistry *GetInstance() {
39 static BoltKernelRegistry instance;
40 return &instance;
41 }
Register(BoltKeyDesc desc,Format df,BoltCreator creator)42 void Register(BoltKeyDesc desc, Format df, BoltCreator creator) { bolt_map_[desc] = std::make_pair(df, creator); }
43
Creator(BoltKeyDesc desc)44 BoltCreator Creator(BoltKeyDesc desc) {
45 auto iter = bolt_map_.find(desc);
46 if (iter != bolt_map_.end()) {
47 return iter->second.second;
48 }
49 return nullptr;
50 }
51
GetKernelFormat(BoltKeyDesc desc)52 Format GetKernelFormat(BoltKeyDesc desc) {
53 auto iter = bolt_map_.find(desc);
54 if (iter != bolt_map_.end()) {
55 return iter->second.first;
56 }
57 return DEFAULT_FORMAT;
58 }
59
60 protected:
61 std::map<BoltKeyDesc, std::pair<Format, BoltCreator>> bolt_map_;
62 };
63
64 class BoltKernelRegistrar {
65 public:
BoltKernelRegistrar(int op_type,TypeId data_type,Format format,BoltCreator creator)66 BoltKernelRegistrar(int op_type, TypeId data_type, Format format, BoltCreator creator) {
67 BoltKernelRegistry::GetInstance()->Register({op_type, data_type}, format, creator);
68 }
69 ~BoltKernelRegistrar() = default;
70 };
71
72 template <class T>
BoltOpt(const ParameterSpec & param_spec,const std::vector<lite::Tensor * > & in,const std::vector<lite::Tensor * > & out,const lite::InnerContext * ctx)73 LiteKernel *BoltOpt(const ParameterSpec ¶m_spec, const std::vector<lite::Tensor *> &in,
74 const std::vector<lite::Tensor *> &out, const lite::InnerContext *ctx) {
75 auto *kernel = new (std::nothrow) T(param_spec, in, out, ctx);
76 return kernel;
77 }
78
/// Registers `creator` with kernel format `format` for (op_type, data_type) by defining a static
/// BoltKernelRegistrar object. `op_type` and `data_type` are token-pasted into the variable name,
/// so they must be plain identifiers, and each pair may be registered only once per translation
/// unit. The expansion already ends with ';'.
#define BLOT_REG_KERNEL(op_type, data_type, format, creator) \
  static BoltKernelRegistrar g_kernel##op_type##data_type##kernelReg(op_type, data_type, format, creator);

// Correctly spelled alias of BLOT_REG_KERNEL. The misspelled name is kept for existing callers.
#ifndef BOLT_REG_KERNEL
#define BOLT_REG_KERNEL(op_type, data_type, format, creator) BLOT_REG_KERNEL(op_type, data_type, format, creator)
#endif
81
82 bool BoltSupportKernel(int op_type, TypeId data_type);
83
84 // registry for extendrt
85 LiteKernel *BoltKernelRegistry(const ParameterSpec ¶m_spec, const std::vector<lite::Tensor *> &inputs,
86 const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
87 kernel::KernelKey *key);
88
89 // registry for litert
90 LiteKernel *BoltKernelRegistry(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
91 const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
92 kernel::KernelKey *key);
93 } // namespace mindspore::kernel::bolt
94 #endif // MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BOLT_BOLT_KERNEL_MANAGER_H_
95