1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_INTERFACE_H_
18 #define MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_INTERFACE_H_
19
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "include/kernel_interface.h"
#include "schema/model_generated.h"
26
27 namespace mindspore {
namespace kernel {
// Forward declaration only: this header never needs the full definition, it just passes
// `const kernel::Kernel *` through to the registry lookup.
class Kernel;
}  // namespace kernel
31 namespace registry {
32 /// \brief KernelInterfaceCreator defined a functor to create KernelInterface.
33 using KernelInterfaceCreator = std::function<std::shared_ptr<kernel::KernelInterface>()>;
34
35 /// \brief RegisterKernelInterface defined registration and acquisition of KernelInterface.
36 class MS_API RegisterKernelInterface {
37 public:
38 /// \brief Static method to register op whose primitive type is custom.
39 ///
40 /// \param[in] provider Define the identification of user.
41 /// \param[in] op_type Define the concrete type of a custom op.
42 /// \param[in] creator Define the KernelInterface create function.
43 ///
44 /// \return Status as a status identification of registering.
45 inline static Status CustomReg(const std::string &provider, const std::string &op_type,
46 const KernelInterfaceCreator creator);
47
48 /// \brief Static method to register op whose primitive type is ordinary.
49 ///
50 /// \param[in] provider Define the identification of user.
51 /// \param[in] op_type Define the ordinary op type.
52 /// \param[in] creator Define the KernelInterface create function.
53 ///
54 /// \return Status as a status identification of registering.
55 inline static Status Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator);
56
57 /// \brief Static method to get registration of a certain op.
58 ///
59 /// \param[in] provider Define the identification of user.
60 /// \param[in] primitive Define the attributes of a certain op.
61 /// \param[in] kernel Define the kernel of a certain op.
62 ///
63 /// \return Boolean value to represent registration of a certain op is existing or not.
64 inline static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
65 const schema::Primitive *primitive,
66 const kernel::Kernel *kernel = nullptr);
67
68 private:
69 static Status CustomReg(const std::vector<char> &provider, const std::vector<char> &op_type,
70 const KernelInterfaceCreator creator);
71 static Status Reg(const std::vector<char> &provider, int op_type, const KernelInterfaceCreator creator);
72 static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::vector<char> &provider,
73 const schema::Primitive *primitive,
74 const kernel::Kernel *kernel = nullptr);
75 };
76
77 /// \brief KernelInterfaceReg defined registration class of KernelInterface.
78 class MS_API KernelInterfaceReg {
79 public:
80 /// \brief Constructor of KernelInterfaceReg to register an ordinary op.
81 ///
82 /// \param[in] provider Define the identification of user.
83 /// \param[in] op_type Define the ordinary op type.
84 /// \param[in] creator Define the KernelInterface create function.
KernelInterfaceReg(const std::string & provider,int op_type,const KernelInterfaceCreator creator)85 KernelInterfaceReg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
86 (void)RegisterKernelInterface::Reg(provider, op_type, creator);
87 }
88
89 /// \brief Constructor of KernelInterfaceReg to register custom op.
90 ///
91 /// \param[in] provider Define the identification of user.
92 /// \param[in] op_type Define the concrete type of a custom op.
93 /// \param[in] creator Define the KernelInterface create function.
KernelInterfaceReg(const std::string & provider,const std::string & op_type,const KernelInterfaceCreator creator)94 KernelInterfaceReg(const std::string &provider, const std::string &op_type, const KernelInterfaceCreator creator) {
95 (void)RegisterKernelInterface::CustomReg(provider, op_type, creator);
96 }
97
98 virtual ~KernelInterfaceReg() = default;
99 };
100
CustomReg(const std::string & provider,const std::string & op_type,const KernelInterfaceCreator creator)101 Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
102 const KernelInterfaceCreator creator) {
103 return CustomReg(StringToChar(provider), StringToChar(op_type), creator);
104 }
105
Reg(const std::string & provider,int op_type,const KernelInterfaceCreator creator)106 Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, const KernelInterfaceCreator creator) {
107 return Reg(StringToChar(provider), op_type, creator);
108 }
109
GetKernelInterface(const std::string & provider,const schema::Primitive * primitive,const kernel::Kernel * kernel)110 std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(const std::string &provider,
111 const schema::Primitive *primitive,
112 const kernel::Kernel *kernel) {
113 return GetKernelInterface(StringToChar(provider), primitive, kernel);
114 }
115
/// \brief Registering macro for an ordinary op, meant to be invoked directly by users at namespace scope.
///
/// Expands to a static KernelInterfaceReg object inside an anonymous namespace; constructing that object
/// performs the registration.
///
/// \param[in] provider Define the identification of user. Pass a bare identifier: it is stringified for the
///                     registration call and token-pasted into the generated variable name.
/// \param[in] op_type Define the ordinary op type. It is token-pasted into the variable name, so it should be a
///                    single identifier token (not a qualified name).
/// \param[in] creator Define the KernelInterface create function.
#define REGISTER_KERNEL_INTERFACE(provider, op_type, creator)                                          \
  namespace {                                                                                          \
  static mindspore::registry::KernelInterfaceReg g_##provider##op_type##_inter_reg(#provider, op_type, \
                                                                                    creator);          \
  }  // namespace
125
/// \brief Registering macro for a custom op, meant to be invoked directly by users at namespace scope.
///
/// Expands to a static KernelInterfaceReg object inside an anonymous namespace; constructing that object
/// performs the registration through the custom-op constructor.
///
/// \param[in] provider Define the identification of user. Pass a bare identifier: it is stringified for the
///                     registration call and token-pasted into the generated variable name.
/// \param[in] op_type Define the concrete type of a custom op. Pass a bare identifier: it is stringified to form
///                    the op-type string and token-pasted into the generated variable name.
/// \param[in] creator Define the KernelInterface create function.
#define REGISTER_CUSTOM_KERNEL_INTERFACE(provider, op_type, creator)            \
  namespace {                                                                   \
  static mindspore::registry::KernelInterfaceReg                                \
    g_##provider##op_type##_custom_inter_reg(#provider, #op_type, creator);     \
  }  // namespace
136 } // namespace registry
137 } // namespace mindspore
138
139 #endif // MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_INTERFACE_H_
140