1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_INCLUDE_API_KERNEL_H 18 #define MINDSPORE_INCLUDE_API_KERNEL_H 19 #include <vector> 20 #include <string> 21 #include <utility> 22 #include <map> 23 #include "schema/model_generated.h" 24 #include "include/api/types.h" 25 #include "include/api/context.h" 26 27 namespace mindspore::kernel { 28 /// \brief The Kernel class is used to define a MindSpore Kernel. 29 class MS_API Kernel { 30 public: 31 Kernel() = default; 32 /// \brief Constructor. 33 /// 34 /// \param[in] inputs define the input tensors for kernel. 35 /// \param[in] outputs define the output tensors for kernel. 36 /// \param[in] primitive define the primitive of kernel generated by flatbuffers. 37 /// \param[in] ctx define the context for kernel. Kernel(const std::vector<mindspore::MSTensor> & inputs,const std::vector<mindspore::MSTensor> & outputs,const schema::Primitive * primitive,const mindspore::Context * ctx)38 Kernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs, 39 const schema::Primitive *primitive, const mindspore::Context *ctx) 40 : context_(ctx), inputs_(std::move(inputs)), outputs_(std::move(outputs)), primitive_(primitive) { 41 Initialize(); 42 } 43 /// \brief Destructor. 44 virtual ~Kernel() = default; 45 /// \brief prepare for executing kernel. 46 /// 47 /// \return result code. 48 virtual int Prepare() = 0; 49 /// \brief execute the kernel. 50 /// 51 /// \return result code. 52 virtual int Execute() = 0; 53 /// \brief resize the kernel input shape, memory need to refresh. 54 /// 55 /// \return result code. 56 virtual int ReSize() = 0; 57 /// \brief set kernel's input tensors. 58 /// 59 /// \param[in] in_tensors define the input tensors. set_inputs(const std::vector<mindspore::MSTensor> & in_tensors)60 virtual void set_inputs(const std::vector<mindspore::MSTensor> &in_tensors) { this->inputs_ = in_tensors; } 61 /// \brief set kernel's input tensor. 62 /// 63 /// \param[in] in_tensor define the input tensor. 64 /// \param[in] index define the index of the input tensor. set_input(mindspore::MSTensor in_tensor,int index)65 virtual void set_input(mindspore::MSTensor in_tensor, int index) { this->inputs_[index] = in_tensor; } 66 /// \brief set kernel's output tensors. 67 /// 68 /// \param[in] out_tensors define the output tensors. set_outputs(const std::vector<mindspore::MSTensor> & out_tensors)69 virtual void set_outputs(const std::vector<mindspore::MSTensor> &out_tensors) { this->outputs_ = out_tensors; } 70 /// \brief set kernel's output tensor. 71 /// 72 /// \param[in] out_tensor define the output tensor. 73 /// \param[in] index define the index of the output tensor. set_output(mindspore::MSTensor out_tensor,int index)74 virtual void set_output(mindspore::MSTensor out_tensor, int index) { this->outputs_[index] = out_tensor; } 75 /// \brief obtain kernel's input tensors. 76 /// 77 /// \return input tensors. inputs()78 virtual const std::vector<mindspore::MSTensor> &inputs() { return this->inputs_; } 79 /// \brief obtain kernel's output tensors. 80 /// 81 /// \return output tensors. outputs()82 virtual const std::vector<mindspore::MSTensor> &outputs() { return this->outputs_; } 83 /// \brief obtain kernel's name. 84 /// 85 /// \return kernel's name. name()86 std::string name() const { return this->name_; } 87 /// \brief set kernel's name. 88 /// 89 /// \param[in] name define the kernel's name. set_name(const std::string & name)90 void set_name(const std::string &name) { this->name_ = name; } 91 /// \brief obtain kernel's context. 92 /// 93 /// \return kernel's context. context()94 const mindspore::Context *context() const { return this->context_; } 95 /// \brief obtain kernel's type. 96 /// 97 /// \return kernel's type. type()98 virtual schema::PrimitiveType type() const { return type_; } 99 /// \brief obtain the primitive of kernel generated by flatbuffers. 100 /// 101 /// \return the primitive of kernel generated by flatbuffers. primitive()102 const schema::Primitive *primitive() const { return this->primitive_; } 103 /// \brief get kernel's attribute 104 /// 105 /// \param[in] key define the kernel's attribute key. GetAttr(const std::string & key)106 std::string GetAttr(const std::string &key) const { 107 auto iter = attrs_.find(key); 108 if (iter != attrs_.end()) { 109 return iter->second; 110 } 111 return ""; 112 } 113 114 protected: 115 /// \brief set kernel's attribute 116 /// 117 /// \param[in] key define the kernel's attribute key. 118 /// \param[in] value define the kernel's attribute value. SetAttr(const std::string & key,const std::string & value)119 void SetAttr(const std::string &key, const std::string &value) { attrs_[key] = value; } 120 121 protected: 122 std::string name_; 123 const mindspore::Context *context_ = nullptr; 124 std::vector<mindspore::MSTensor> inputs_; 125 std::vector<mindspore::MSTensor> outputs_; 126 schema::PrimitiveType type_ = schema::PrimitiveType_NONE; 127 const schema::Primitive *primitive_ = nullptr; 128 std::map<std::string, std::string> attrs_; 129 130 private: 131 void Initialize(); 132 }; 133 } // namespace mindspore::kernel 134 135 #endif // MINDSPORE_INCLUDE_API_KERNEL_H 136