• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_H_
18 #define MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_H_
19 
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "schema/model_generated.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/kernel.h"
#include "include/api/data_type.h"
#include "include/api/status.h"
30 
31 namespace mindspore {
32 namespace registry {
33 /// \brief KernelDesc defined kernel's basic attribute.
34 struct KernelDesc {
35   DataType data_type;   /**< kernel data type argument */
36   int type;             /**< op type argument */
37   std::string arch;     /**< deviceType argument */
38   std::string provider; /**< user identification argument */
39 };
40 
41 /// \brief CreateKernel Defined a functor to create a kernel.
42 ///
43 /// \param[in] inputs Define input tensors of kernel.
44 /// \param[in] outputs Define output tensors of kernel.
45 /// \param[in] primitive Define attributes of op.
46 /// \param[in] ctx Define for holding environment variables during runtime.
47 ///
48 /// \return Smart Pointer of kernel.
49 using CreateKernel = std::function<std::shared_ptr<kernel::Kernel>(
50   const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs, const schema::Primitive *primitive,
51   const mindspore::Context *ctx)>;
52 
53 /// \brief RegisterKernel Defined registration of kernel.
54 class MS_API RegisterKernel {
55  public:
56   /// \brief Static method to register kernel which is correspondng to an ordinary op.
57   ///
58   /// \param[in] arch Define deviceType, such as CPU.
59   /// \param[in] provider Define the identification of user.
60   /// \param[in] data_type Define kernel's input data type.
61   /// \param[in] type Define the ordinary op type.
62   /// \param[in] creator Define a function pointer to create a kernel.
63   ///
64   /// \return Status as a status identification of registering.
65   static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
66                           const CreateKernel creator);
67 
68   /// \brief Static method to register kernel which is corresponding to custom op.
69   ///
70   /// \param[in] arch Define deviceType, such as CPU.
71   /// \param[in] provider Define the identification of user.
72   /// \param[in] data_type Define kernel's input data type.
73   /// \param[in] type Define the concrete type of a custom op.
74   /// \param[in] creator Define a function pointer to create a kernel.
75   ///
76   /// \return Status as a status identification of registering.
77   static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
78                                 const std::string &type, const CreateKernel creator);
79 
80   /// \brief Static methon to get a kernel's create function.
81   ///
82   /// \param[in] desc Define kernel's basic attribute.
83   /// \param[in] primitive Define the primitive of kernel generated by flatbuffers.
84   ///
85   /// \return Function pointer to create a kernel.
86   static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDesc *desc);
87 };
88 
89 /// \brief KernelReg Defined registration class of kernel.
90 class MS_API KernelReg {
91  public:
92   /// \brief Destructor of KernelReg.
93   ~KernelReg() = default;
94 
95   /// \brief Method to register ordinary op.
96   ///
97   /// \param[in] arch Define deviceType, such as CPU.
98   /// \param[in] provider Define the identification of user.
99   /// \param[in] data_type Define kernel's input data type.
100   /// \param[in] op_type Define the ordinary op type.
101   /// \param[in] creator Define a function pointer to create a kernel.
KernelReg(const std::string & arch,const std::string & provider,DataType data_type,int op_type,const CreateKernel creator)102   KernelReg(const std::string &arch, const std::string &provider, DataType data_type, int op_type,
103             const CreateKernel creator) {
104     RegisterKernel::RegKernel(arch, provider, data_type, op_type, creator);
105   }
106 
107   /// \brief Method to register customized op.
108   ///
109   /// \param[in] arch Define deviceType, such as CPU.
110   /// \param[in] provider Define the identification of user.
111   /// \param[in] data_type Define kernel's input data type.
112   /// \param[in] op_type Define the concrete type of a custom op.
113   /// \param[in] creator Define a function pointer to create a kernel.
KernelReg(const std::string & arch,const std::string & provider,DataType data_type,const std::string & op_type,const CreateKernel creator)114   KernelReg(const std::string &arch, const std::string &provider, DataType data_type, const std::string &op_type,
115             const CreateKernel creator) {
116     RegisterKernel::RegCustomKernel(arch, provider, data_type, op_type, creator);
117   }
118 };
119 
/// \brief Registering macro for an ordinary-op kernel, intended to be called by users directly.
///
/// Expands to an unnamed-namespace KernelReg object whose constructor registers the kernel, so it must
/// be used at namespace scope (not inside a function or class).
///
/// \param[in] arch Define deviceType, such as CPU. Stringified (`#arch`): pass a bare token, not a string literal.
/// \param[in] provider Define the identification of user. Stringified (`#provider`) as well.
/// \param[in] data_type Define kernel's input data type.
/// \param[in] op_type Define the ordinary op type.
/// \param[in] creator Define a functor (CreateKernel) to create a kernel.
///
/// \note arch, provider, data_type and op_type are token-pasted (`##`) into the registrar's variable
///       name, so each must be a single token that can form part of an identifier. A qualified name
///       such as `ns::Name` cannot be pasted; bring it into scope with a `using` declaration first.
#define REGISTER_KERNEL(arch, provider, data_type, op_type, creator)                                                   \
  namespace {                                                                                                          \
  static mindspore::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, \
                                                                                          op_type, creator);           \
  }  // namespace
132 
/// \brief Registering macro for a custom-op kernel, intended to be called by users directly.
///
/// Expands to an unnamed-namespace KernelReg object whose constructor registers the kernel, so it must
/// be used at namespace scope (not inside a function or class).
///
/// \param[in] arch Define deviceType, such as CPU. Stringified (`#arch`): pass a bare token, not a string literal.
/// \param[in] provider Define the identification of user. Stringified (`#provider`) as well.
/// \param[in] data_type Define kernel's input data type.
/// \param[in] op_type Define the concrete type of a custom op. Stringified (`#op_type`): pass a bare token.
/// \param[in] creator Define a functor (CreateKernel) to create a kernel.
///
/// \note arch, provider, data_type and op_type are token-pasted (`##`) into the registrar's variable
///       name, so each must be a single token that can form part of an identifier. A qualified name
///       such as `ns::Name` cannot be pasted; bring it into scope with a `using` declaration first.
#define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator)                                            \
  namespace {                                                                                                          \
  static mindspore::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, \
                                                                                          #op_type, creator);          \
  }  // namespace
145 }  // namespace registry
146 }  // namespace mindspore
147 
148 #endif  // MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_H_
149