1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_INCLUDE_API_KERNEL_API_H 18 #define MINDSPORE_INCLUDE_API_KERNEL_API_H 19 #include <vector> 20 #include <string> 21 #include <utility> 22 #include <map> 23 #include "include/api/types.h" 24 #include "include/api/status.h" 25 namespace mindspore { 26 class Context; 27 namespace kernel { 28 /// \brief The Kernel class is used to define a MindSpore Kernel. 29 class MS_API MSKernel { 30 public: 31 MSKernel() = default; 32 /// \brief Constructor. 33 /// 34 /// \param[in] inputs define the input tensors for kernel. 35 /// \param[in] outputs define the output tensors for kernel. 36 /// \param[in] primitive define the primitive of kernel. 37 /// \param[in] ctx define the context for kernel. MSKernel(std::vector<mindspore::MSTensor> inputs,std::vector<mindspore::MSTensor> outputs,const mindspore::Context * ctx)38 MSKernel(std::vector<mindspore::MSTensor> inputs, std::vector<mindspore::MSTensor> outputs, 39 const mindspore::Context *ctx) 40 : context_(ctx), inputs_(std::move(inputs)), outputs_(std::move(outputs)) {} 41 /// \brief Destructor. 42 virtual ~MSKernel() = default; 43 /// \brief infer shape, datatype and format for output tensor of kernel. 44 /// 45 /// \return result code. InferShape()46 virtual int InferShape() { return kLiteError; } 47 /// \brief prepare for executing kernel. 48 /// 49 /// \return result code. 50 virtual int Prepare() = 0; 51 /// \brief execute the kernel. 52 /// 53 /// \return result code. 54 virtual int Execute() = 0; 55 /// \brief resize the kernel input shape, memory need to refresh. 56 /// 57 /// \return result code. 58 virtual int ReSize() = 0; 59 /// \brief set kernel's input tensors. 60 /// 61 /// \param[in] in_tensors define the input tensors. set_inputs(const std::vector<mindspore::MSTensor> & in_tensors)62 virtual void set_inputs(const std::vector<mindspore::MSTensor> &in_tensors) { this->inputs_ = in_tensors; } 63 /// \brief set kernel's input tensor. 64 /// 65 /// \param[in] in_tensor define the input tensor. 66 /// \param[in] index define the index of the input tensor. set_input(mindspore::MSTensor in_tensor,int index)67 virtual void set_input(mindspore::MSTensor in_tensor, int index) { this->inputs_[index] = in_tensor; } 68 /// \brief set kernel's output tensors. 69 /// 70 /// \param[in] out_tensors define the output tensors. set_outputs(const std::vector<mindspore::MSTensor> & out_tensors)71 virtual void set_outputs(const std::vector<mindspore::MSTensor> &out_tensors) { this->outputs_ = out_tensors; } 72 /// \brief set kernel's output tensor. 73 /// 74 /// \param[in] out_tensor define the output tensor. 75 /// \param[in] index define the index of the output tensor. set_output(mindspore::MSTensor out_tensor,int index)76 virtual void set_output(mindspore::MSTensor out_tensor, int index) { this->outputs_[index] = out_tensor; } 77 /// \brief obtain kernel's input tensors. 78 /// 79 /// \return input tensors. inputs()80 virtual const std::vector<mindspore::MSTensor> &inputs() { return this->inputs_; } 81 /// \brief obtain kernel's output tensors. 82 /// 83 /// \return output tensors. outputs()84 virtual const std::vector<mindspore::MSTensor> &outputs() { return this->outputs_; } 85 /// \brief obtain kernel's name. 86 /// 87 /// \return kernel's name. name()88 virtual std::string name() const { return this->name_; } 89 /// \brief set kernel's name. 90 /// 91 /// \param[in] name define the kernel's name. set_name(const std::string & name)92 void set_name(const std::string &name) { this->name_ = name; } 93 /// \brief obtain kernel's context. 94 /// 95 /// \return kernel's context. context()96 const mindspore::Context *context() const { return this->context_; } 97 98 /// \brief get kernel's attribute. 99 /// 100 /// \param[in] key define the kernel's attribute key. GetAttr(const std::string & key)101 std::string GetAttr(const std::string &key) const { 102 auto iter = attrs_.find(key); 103 if (iter != attrs_.end()) { 104 return iter->second; 105 } 106 return ""; 107 } 108 109 /// \brief set kernel's config. 110 /// 111 /// \param[in] config define the kernel's config. SetConfig(const std::map<std::string,std::map<std::string,std::string>> * config)112 void SetConfig(const std::map<std::string, std::map<std::string, std::string>> *config) { config_ = config; } 113 /// \brief set kernel's config. 114 /// 115 /// \param[in] section define the section of the kernel's config. GetConfig(const std::string & section)116 std::map<std::string, std::string> GetConfig(const std::string §ion) const { 117 if (config_ == nullptr) { 118 return std::map<std::string, std::string>(); 119 } 120 auto iter = config_->find(section); 121 if (iter != config_->end()) { 122 return iter->second; 123 } 124 return std::map<std::string, std::string>(); 125 } 126 127 protected: 128 /// \brief set kernel's attribute 129 /// 130 /// \param[in] key define the kernel's attribute key. 131 /// \param[in] value define the kernel's attribute value. SetAttr(const std::string & key,const std::string & value)132 void SetAttr(const std::string &key, const std::string &value) { attrs_[key] = value; } 133 134 std::string name_; 135 const mindspore::Context *context_ = nullptr; 136 std::vector<mindspore::MSTensor> inputs_; 137 std::vector<mindspore::MSTensor> outputs_; 138 std::map<std::string, std::string> attrs_; 139 const std::map<std::string, std::map<std::string, std::string>> *config_ = nullptr; 140 }; 141 142 /// \brief The Kernel class is used to define a MindSpore Kernel with specific primitive. 143 template <typename Primitive> 144 class MS_API IKernel : public MSKernel { 145 public: 146 IKernel() = default; 147 /// \brief Constructor. 148 /// 149 /// \param[in] inputs define the input tensors for kernel. 150 /// \param[in] outputs define the output tensors for kernel. 151 /// \param[in] primitive define the primitive of kernel. 152 /// \param[in] ctx define the context for kernel. IKernel(const std::vector<mindspore::MSTensor> & inputs,const std::vector<mindspore::MSTensor> & outputs,const Primitive * primitive,const mindspore::Context * ctx)153 IKernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs, 154 const Primitive *primitive, const mindspore::Context *ctx) 155 : MSKernel(inputs, outputs, ctx), primitive_(primitive) {} 156 /// \brief Destructor. 157 ~IKernel() override = default; 158 /// \brief get the primitive of kernel. 159 /// 160 /// \return the primitive of kernel generated by flatbuffers. primitive()161 const Primitive *primitive() const { return this->primitive_; } 162 163 protected: 164 const Primitive *primitive_ = nullptr; 165 }; 166 } // namespace kernel 167 } // namespace mindspore 168 #endif // MINDSPORE_INCLUDE_API_KERNEL_API_H 169