• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_INTERFACE_H_
18 #define MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_INTERFACE_H_
19 
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "include/kernel_interface.h"
#include "schema/model_generated.h"
26 
27 namespace mindspore {
// Forward declaration only: this header refers to kernel::Kernel solely through pointers
// (see GetKernelInterface), so the full definition is not needed here.
namespace kernel {
class Kernel;
}  // namespace kernel
31 namespace registry {
32 /// \brief KernelInterfaceCreator defined a functor to create KernelInterface.
33 using KernelInterfaceCreator = std::function<std::shared_ptr<kernel::KernelInterface>()>;
34 
35 /// \brief RegisterKernelInterface defined registration and acquisition of KernelInterface.
36 class MS_API RegisterKernelInterface {
37  public:
38   /// \brief Static method to register op whose primitive type is custom.
39   ///
40   /// \param[in] provider Define the identification of user.
41   /// \param[in] op_type Define the concrete type of a custom op.
42   /// \param[in] creator Define the KernelInterface create function.
43   ///
44   /// \return Status as a status identification of registering.
45   inline static Status CustomReg(const std::string &provider, const std::string &op_type,
46                                  const KernelInterfaceCreator creator);
47 
48   /// \brief Static method to register op whose primitive type is ordinary.
49   ///
50   /// \param[in] provider Define the identification of user.
51   /// \param[in] op_type Define the ordinary op type.
52   /// \param[in] creator Define the KernelInterface create function.
53   ///
54   /// \return Status as a status identification of registering.
55   inline static Status Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator);
56 
57   /// \brief Static method to get registration of a certain op.
58   ///
59   /// \param[in] provider Define the identification of user.
60   /// \param[in] primitive Define the attributes of a certain op.
61   /// \param[in] kernel Define the kernel of a certain op.
62   ///
63   /// \return Boolean value to represent registration of a certain op is existing or not.
64   inline static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
65                                                                             const schema::Primitive *primitive,
66                                                                             const kernel::Kernel *kernel = nullptr);
67 
68  private:
69   static Status CustomReg(const std::vector<char> &provider, const std::vector<char> &op_type,
70                           const KernelInterfaceCreator creator);
71   static Status Reg(const std::vector<char> &provider, int op_type, const KernelInterfaceCreator creator);
72   static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::vector<char> &provider,
73                                                                      const schema::Primitive *primitive,
74                                                                      const kernel::Kernel *kernel = nullptr);
75 };
76 
77 /// \brief KernelInterfaceReg defined registration class of KernelInterface.
78 class MS_API KernelInterfaceReg {
79  public:
80   /// \brief Constructor of KernelInterfaceReg to register an ordinary op.
81   ///
82   /// \param[in] provider Define the identification of user.
83   /// \param[in] op_type Define the ordinary op type.
84   /// \param[in] creator Define the KernelInterface create function.
KernelInterfaceReg(const std::string & provider,int op_type,const KernelInterfaceCreator creator)85   KernelInterfaceReg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
86     (void)RegisterKernelInterface::Reg(provider, op_type, creator);
87   }
88 
89   /// \brief Constructor of KernelInterfaceReg to register custom op.
90   ///
91   /// \param[in] provider Define the identification of user.
92   /// \param[in] op_type Define the concrete type of a custom op.
93   /// \param[in] creator Define the KernelInterface create function.
KernelInterfaceReg(const std::string & provider,const std::string & op_type,const KernelInterfaceCreator creator)94   KernelInterfaceReg(const std::string &provider, const std::string &op_type, const KernelInterfaceCreator creator) {
95     (void)RegisterKernelInterface::CustomReg(provider, op_type, creator);
96   }
97 
98   virtual ~KernelInterfaceReg() = default;
99 };
100 
CustomReg(const std::string & provider,const std::string & op_type,const KernelInterfaceCreator creator)101 Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
102                                           const KernelInterfaceCreator creator) {
103   return CustomReg(StringToChar(provider), StringToChar(op_type), creator);
104 }
105 
Reg(const std::string & provider,int op_type,const KernelInterfaceCreator creator)106 Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
107   return Reg(StringToChar(provider), op_type, creator);
108 }
109 
GetKernelInterface(const std::string & provider,const schema::Primitive * primitive,const kernel::Kernel * kernel)110 std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(const std::string &provider,
111                                                                                      const schema::Primitive *primitive,
112                                                                                      const kernel::Kernel *kernel) {
113   return GetKernelInterface(StringToChar(provider), primitive, kernel);
114 }
115 
/// \brief Registers the KernelInterface creator of an ordinary op. Users invoke it directly at namespace scope.
///
/// The expansion defines a static mindspore::registry::KernelInterfaceReg inside an unnamed namespace.
/// Constructing that object performs the registration.
///
/// \param[in] provider Identification of the user. It is stringified (#provider) for the registration call and
///            token-pasted into the generated variable name, so it must be a plain identifier.
/// \param[in] op_type The ordinary op type. It is passed unchanged as the int argument and token-pasted into
///            the generated variable name, so it must be a single identifier (e.g. an unqualified enum constant).
/// \param[in] creator The KernelInterface create function.
#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator)                                 \
  namespace {                                                                                 \
  static mindspore::registry::KernelInterfaceReg g_##provider##op_type##_inter_reg(#provider, \
                                                                                   op_type,   \
                                                                                   creator);  \
  }  // namespace
125 
/// \brief Registers the KernelInterface creator of a custom op. Users invoke it directly at namespace scope.
///
/// The expansion defines a static mindspore::registry::KernelInterfaceReg inside an unnamed namespace.
/// Constructing that object performs the registration.
///
/// \param[in] provider Identification of the user. It is stringified (#provider) and token-pasted into the
///            generated variable name, so it must be a plain identifier.
/// \param[in] op_type The concrete custom op type. It is stringified (#op_type), so pass the bare name rather
///            than a string literal; it is also token-pasted into the generated variable name.
/// \param[in] creator The KernelInterface create function.
#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator)                                 \
  namespace {                                                                                        \
  static mindspore::registry::KernelInterfaceReg g_##provider##op_type##_custom_inter_reg(#provider, \
                                                                                          #op_type,  \
                                                                                          creator);  \
  }  // namespace
136 }  // namespace registry
137 }  // namespace mindspore
138 
139 #endif  // MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_INTERFACE_H_
140