1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_H_ 18 #define MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_H_ 19 20 #include <set> 21 #include <string> 22 #include <vector> 23 #include <memory> 24 #include "schema/model_generated.h" 25 #include "include/api/context.h" 26 #include "include/api/types.h" 27 #include "include/api/kernel.h" 28 #include "include/api/data_type.h" 29 #include "include/api/status.h" 30 31 namespace mindspore { 32 namespace registry { 33 /// \brief KernelDesc defined kernel's basic attribute. 34 struct KernelDesc { 35 DataType data_type; /**< kernel data type argument */ 36 int type; /**< op type argument */ 37 std::string arch; /**< deviceType argument */ 38 std::string provider; /**< user identification argument */ 39 }; 40 41 /// \brief CreateKernel Defined a functor to create a kernel. 42 /// 43 /// \param[in] inputs Define input tensors of kernel. 44 /// \param[in] outputs Define output tensors of kernel. 45 /// \param[in] primitive Define attributes of op. 46 /// \param[in] ctx Define for holding environment variables during runtime. 47 /// 48 /// \return Smart Pointer of kernel. 49 using CreateKernel = std::function<std::shared_ptr<kernel::Kernel>( 50 const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs, const schema::Primitive *primitive, 51 const mindspore::Context *ctx)>; 52 53 /// \brief RegisterKernel Defined registration of kernel. 54 class MS_API RegisterKernel { 55 public: 56 /// \brief Static method to register kernel which is correspondng to an ordinary op. 57 /// 58 /// \param[in] arch Define deviceType, such as CPU. 59 /// \param[in] provider Define the identification of user. 60 /// \param[in] data_type Define kernel's input data type. 61 /// \param[in] type Define the ordinary op type. 62 /// \param[in] creator Define a function pointer to create a kernel. 63 /// 64 /// \return Status as a status identification of registering. 65 static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type, 66 const CreateKernel creator); 67 68 /// \brief Static method to register kernel which is corresponding to custom op. 69 /// 70 /// \param[in] arch Define deviceType, such as CPU. 71 /// \param[in] provider Define the identification of user. 72 /// \param[in] data_type Define kernel's input data type. 73 /// \param[in] type Define the concrete type of a custom op. 74 /// \param[in] creator Define a function pointer to create a kernel. 75 /// 76 /// \return Status as a status identification of registering. 77 static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, 78 const std::string &type, const CreateKernel creator); 79 80 /// \brief Static methon to get a kernel's create function. 81 /// 82 /// \param[in] desc Define kernel's basic attribute. 83 /// \param[in] primitive Define the primitive of kernel generated by flatbuffers. 84 /// 85 /// \return Function pointer to create a kernel. 86 static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDesc *desc); 87 }; 88 89 /// \brief KernelReg Defined registration class of kernel. 90 class MS_API KernelReg { 91 public: 92 /// \brief Destructor of KernelReg. 93 ~KernelReg() = default; 94 95 /// \brief Method to register ordinary op. 96 /// 97 /// \param[in] arch Define deviceType, such as CPU. 98 /// \param[in] provider Define the identification of user. 99 /// \param[in] data_type Define kernel's input data type. 100 /// \param[in] op_type Define the ordinary op type. 101 /// \param[in] creator Define a function pointer to create a kernel. KernelReg(const std::string & arch,const std::string & provider,DataType data_type,int op_type,const CreateKernel creator)102 KernelReg(const std::string &arch, const std::string &provider, DataType data_type, int op_type, 103 const CreateKernel creator) { 104 RegisterKernel::RegKernel(arch, provider, data_type, op_type, creator); 105 } 106 107 /// \brief Method to register customized op. 108 /// 109 /// \param[in] arch Define deviceType, such as CPU. 110 /// \param[in] provider Define the identification of user. 111 /// \param[in] data_type Define kernel's input data type. 112 /// \param[in] op_type Define the concrete type of a custom op. 113 /// \param[in] creator Define a function pointer to create a kernel. KernelReg(const std::string & arch,const std::string & provider,DataType data_type,const std::string & op_type,const CreateKernel creator)114 KernelReg(const std::string &arch, const std::string &provider, DataType data_type, const std::string &op_type, 115 const CreateKernel creator) { 116 RegisterKernel::RegCustomKernel(arch, provider, data_type, op_type, creator); 117 } 118 }; 119 120 /// \brief Defined registering macro to register ordinary op kernel, which called by user directly. 121 /// 122 /// \param[in] arch Define deviceType, such as CPU. 123 /// \param[in] provider Define the identification of user. 124 /// \param[in] data_type Define kernel's input data type. 125 /// \param[in] op_type Define the ordinary op type. 126 /// \param[in] creator Define a function pointer to create a kernel. 127 #define REGISTER_KERNEL(arch, provider, data_type, op_type, creator) \ 128 namespace { \ 129 static mindspore::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, \ 130 op_type, creator); \ 131 } // namespace 132 133 /// \brief Defined registering macro to register custom op kernel, which called by user directly. 134 /// 135 /// \param[in] arch Define deviceType, such as CPU. 136 /// \param[in] provider Define the identification of user. 137 /// \param[in] data_type Define kernel's input data type. 138 /// \param[in] op_type Define the concrete type of a custom op. 139 /// \param[in] creator Define a function pointer to create a kernel. 140 #define REGISTER_CUSTOM_KERNEL(arch, provider, data_type, op_type, creator) \ 141 namespace { \ 142 static mindspore::registry::KernelReg g_##arch##provider##data_type##op_type##kernelReg(#arch, #provider, data_type, \ 143 #op_type, creator); \ 144 } // namespace 145 } // namespace registry 146 } // namespace mindspore 147 148 #endif // MINDSPORE_LITE_INCLUDE_REGISTRY_REGISTER_KERNEL_H_ 149