• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BOLT_BOLT_KERNEL_MANAGER_H_
18 #define MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BOLT_BOLT_KERNEL_MANAGER_H_
19 
20 #include <map>
21 #include <vector>
22 #include <utility>
23 #include "src/executor/kernel_exec.h"
24 #include "bolt/common/uni/include/parameter_spec.h"
25 
26 namespace mindspore::kernel::bolt {
27 struct BoltKeyDesc {
28   int op_;
29   TypeId dt_;
30   bool operator<(const BoltKeyDesc &comp) const { return (op_ != comp.op_) ? (op_ < comp.op_) : (dt_ < comp.dt_); }
31 };
32 
33 typedef LiteKernel *(*BoltCreator)(const ParameterSpec &param_spec, const std::vector<lite::Tensor *> &in,
34                                    const std::vector<lite::Tensor *> &out, const lite::InnerContext *ctx);
35 
36 class BoltKernelRegistry {
37  public:
GetInstance()38   static BoltKernelRegistry *GetInstance() {
39     static BoltKernelRegistry instance;
40     return &instance;
41   }
Register(BoltKeyDesc desc,Format df,BoltCreator creator)42   void Register(BoltKeyDesc desc, Format df, BoltCreator creator) { bolt_map_[desc] = std::make_pair(df, creator); }
43 
Creator(BoltKeyDesc desc)44   BoltCreator Creator(BoltKeyDesc desc) {
45     auto iter = bolt_map_.find(desc);
46     if (iter != bolt_map_.end()) {
47       return iter->second.second;
48     }
49     return nullptr;
50   }
51 
GetKernelFormat(BoltKeyDesc desc)52   Format GetKernelFormat(BoltKeyDesc desc) {
53     auto iter = bolt_map_.find(desc);
54     if (iter != bolt_map_.end()) {
55       return iter->second.first;
56     }
57     return DEFAULT_FORMAT;
58   }
59 
60  protected:
61   std::map<BoltKeyDesc, std::pair<Format, BoltCreator>> bolt_map_;
62 };
63 
64 class BoltKernelRegistrar {
65  public:
BoltKernelRegistrar(int op_type,TypeId data_type,Format format,BoltCreator creator)66   BoltKernelRegistrar(int op_type, TypeId data_type, Format format, BoltCreator creator) {
67     BoltKernelRegistry::GetInstance()->Register({op_type, data_type}, format, creator);
68   }
69   ~BoltKernelRegistrar() = default;
70 };
71 
72 template <class T>
BoltOpt(const ParameterSpec & param_spec,const std::vector<lite::Tensor * > & in,const std::vector<lite::Tensor * > & out,const lite::InnerContext * ctx)73 LiteKernel *BoltOpt(const ParameterSpec &param_spec, const std::vector<lite::Tensor *> &in,
74                     const std::vector<lite::Tensor *> &out, const lite::InnerContext *ctx) {
75   auto *kernel = new (std::nothrow) T(param_spec, in, out, ctx);
76   return kernel;
77 }
78 
// Defines a file-static BoltKernelRegistrar whose constructor registers `creator` (with kernel
// format `format`) for the key {op_type, data_type}.
// Constraints:
//  - The variable name is built by token-pasting op_type and data_type, so both arguments must be
//    single identifiers (no `::`, casts, or other expressions).
//  - A given (op_type, data_type) pair can be used only once per translation unit.
//  - The expansion already ends with ';'.
// NOTE: "BLOT" is a misspelling of "BOLT"; the name is kept so existing registrations compile.
#define BLOT_REG_KERNEL(op_type, data_type, format, creator) \
  static BoltKernelRegistrar g_kernel##op_type##data_type##kernelReg(op_type, data_type, format, creator);

// Correctly spelled equivalent of BLOT_REG_KERNEL with an identical expansion. The body is
// duplicated rather than forwarded so the arguments are token-pasted unexpanded, exactly as in
// BLOT_REG_KERNEL.
#ifndef BOLT_REG_KERNEL
#define BOLT_REG_KERNEL(op_type, data_type, format, creator) \
  static BoltKernelRegistrar g_kernel##op_type##data_type##kernelReg(op_type, data_type, format, creator);
#endif
81 
// Declared here, defined elsewhere. Presumably reports whether a bolt kernel is available for
// (op_type, data_type) -- TODO confirm against the implementation.
bool BoltSupportKernel(int op_type, TypeId data_type);

// Registry for extendrt: returns a LiteKernel* built from an already-prepared ParameterSpec.
// NOTE(review): the result handling (nullptr on failure? ownership?) is not visible here -- confirm.
// NOTE: these free functions have the same name as the class BoltKernelRegistry declared earlier in
// this namespace. That is legal C++, but wherever these declarations are visible the function name
// hides the class for ordinary unqualified lookup. Refer to the class as `class BoltKernelRegistry`
// in such contexts. `BoltKernelRegistry::GetInstance()` still finds the class, because lookup of a
// name before `::` ignores functions.
LiteKernel *BoltKernelRegistry(const ParameterSpec &param_spec, const std::vector<lite::Tensor *> &inputs,
                               const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                               kernel::KernelKey *key);

// Registry for litert: same role as the overload above, but takes the generic OpParameter.
// Presumably converts it to a ParameterSpec internally -- TODO confirm.
LiteKernel *BoltKernelRegistry(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                               const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                               kernel::KernelKey *key);
93 }  // namespace mindspore::kernel::bolt
94 #endif  // MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BOLT_BOLT_KERNEL_MANAGER_H_
95