• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_INCLUDE_API_KERNEL_API_H
18 #define MINDSPORE_INCLUDE_API_KERNEL_API_H
19 #include <vector>
20 #include <string>
21 #include <utility>
22 #include <map>
23 #include "include/api/types.h"
24 #include "include/api/status.h"
25 namespace mindspore {
26 class Context;
27 namespace kernel {
28 /// \brief The Kernel class is used to define a MindSpore Kernel.
29 class MS_API MSKernel {
30  public:
31   MSKernel() = default;
32   /// \brief Constructor.
33   ///
34   /// \param[in] inputs define the input tensors for kernel.
35   /// \param[in] outputs define the output tensors for kernel.
36   /// \param[in] primitive define the primitive of kernel.
37   /// \param[in] ctx define the context for kernel.
MSKernel(std::vector<mindspore::MSTensor> inputs,std::vector<mindspore::MSTensor> outputs,const mindspore::Context * ctx)38   MSKernel(std::vector<mindspore::MSTensor> inputs, std::vector<mindspore::MSTensor> outputs,
39            const mindspore::Context *ctx)
40       : context_(ctx), inputs_(std::move(inputs)), outputs_(std::move(outputs)) {}
41   /// \brief Destructor.
42   virtual ~MSKernel() = default;
43   /// \brief infer shape, datatype and format for output tensor of kernel.
44   ///
45   /// \return result code.
InferShape()46   virtual int InferShape() { return kLiteError; }
47   /// \brief prepare for executing kernel.
48   ///
49   /// \return result code.
50   virtual int Prepare() = 0;
51   /// \brief execute the kernel.
52   ///
53   /// \return result code.
54   virtual int Execute() = 0;
55   /// \brief resize the kernel input shape, memory need to refresh.
56   ///
57   /// \return result code.
58   virtual int ReSize() = 0;
59   /// \brief set kernel's input tensors.
60   ///
61   /// \param[in] in_tensors define the input tensors.
set_inputs(const std::vector<mindspore::MSTensor> & in_tensors)62   virtual void set_inputs(const std::vector<mindspore::MSTensor> &in_tensors) { this->inputs_ = in_tensors; }
63   /// \brief set kernel's input tensor.
64   ///
65   /// \param[in] in_tensor define the input tensor.
66   /// \param[in] index define the index of the input tensor.
set_input(mindspore::MSTensor in_tensor,int index)67   virtual void set_input(mindspore::MSTensor in_tensor, int index) { this->inputs_[index] = in_tensor; }
68   /// \brief set kernel's output tensors.
69   ///
70   /// \param[in] out_tensors define the output tensors.
set_outputs(const std::vector<mindspore::MSTensor> & out_tensors)71   virtual void set_outputs(const std::vector<mindspore::MSTensor> &out_tensors) { this->outputs_ = out_tensors; }
72   /// \brief set kernel's output tensor.
73   ///
74   /// \param[in] out_tensor define the output tensor.
75   /// \param[in] index define the index of the output tensor.
set_output(mindspore::MSTensor out_tensor,int index)76   virtual void set_output(mindspore::MSTensor out_tensor, int index) { this->outputs_[index] = out_tensor; }
77   /// \brief obtain kernel's input tensors.
78   ///
79   /// \return input tensors.
inputs()80   virtual const std::vector<mindspore::MSTensor> &inputs() { return this->inputs_; }
81   /// \brief obtain kernel's output tensors.
82   ///
83   /// \return output tensors.
outputs()84   virtual const std::vector<mindspore::MSTensor> &outputs() { return this->outputs_; }
85   /// \brief obtain kernel's name.
86   ///
87   /// \return kernel's name.
name()88   virtual std::string name() const { return this->name_; }
89   /// \brief set kernel's name.
90   ///
91   /// \param[in] name define the kernel's name.
set_name(const std::string & name)92   void set_name(const std::string &name) { this->name_ = name; }
93   /// \brief obtain kernel's context.
94   ///
95   /// \return kernel's context.
context()96   const mindspore::Context *context() const { return this->context_; }
97 
98   /// \brief get kernel's attribute.
99   ///
100   /// \param[in] key define the kernel's attribute key.
GetAttr(const std::string & key)101   std::string GetAttr(const std::string &key) const {
102     auto iter = attrs_.find(key);
103     if (iter != attrs_.end()) {
104       return iter->second;
105     }
106     return "";
107   }
108 
109   /// \brief set kernel's config.
110   ///
111   /// \param[in] config define the kernel's config.
SetConfig(const std::map<std::string,std::map<std::string,std::string>> * config)112   void SetConfig(const std::map<std::string, std::map<std::string, std::string>> *config) { config_ = config; }
113   /// \brief set kernel's config.
114   ///
115   /// \param[in] section define the section of the kernel's config.
GetConfig(const std::string & section)116   std::map<std::string, std::string> GetConfig(const std::string &section) const {
117     if (config_ == nullptr) {
118       return std::map<std::string, std::string>();
119     }
120     auto iter = config_->find(section);
121     if (iter != config_->end()) {
122       return iter->second;
123     }
124     return std::map<std::string, std::string>();
125   }
126 
127  protected:
128   /// \brief set kernel's attribute
129   ///
130   /// \param[in] key define the kernel's attribute key.
131   /// \param[in] value define the kernel's attribute value.
SetAttr(const std::string & key,const std::string & value)132   void SetAttr(const std::string &key, const std::string &value) { attrs_[key] = value; }
133 
134   std::string name_;
135   const mindspore::Context *context_ = nullptr;
136   std::vector<mindspore::MSTensor> inputs_;
137   std::vector<mindspore::MSTensor> outputs_;
138   std::map<std::string, std::string> attrs_;
139   const std::map<std::string, std::map<std::string, std::string>> *config_ = nullptr;
140 };
141 
142 /// \brief The Kernel class is used to define a MindSpore Kernel with specific primitive.
143 template <typename Primitive>
144 class MS_API IKernel : public MSKernel {
145  public:
146   IKernel() = default;
147   /// \brief Constructor.
148   ///
149   /// \param[in] inputs define the input tensors for kernel.
150   /// \param[in] outputs define the output tensors for kernel.
151   /// \param[in] primitive define the primitive of kernel.
152   /// \param[in] ctx define the context for kernel.
IKernel(const std::vector<mindspore::MSTensor> & inputs,const std::vector<mindspore::MSTensor> & outputs,const Primitive * primitive,const mindspore::Context * ctx)153   IKernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
154           const Primitive *primitive, const mindspore::Context *ctx)
155       : MSKernel(inputs, outputs, ctx), primitive_(primitive) {}
156   /// \brief Destructor.
157   ~IKernel() override = default;
158   /// \brief get the primitive of kernel.
159   ///
160   /// \return the primitive of kernel generated by flatbuffers.
primitive()161   const Primitive *primitive() const { return this->primitive_; }
162 
163  protected:
164   const Primitive *primitive_ = nullptr;
165 };
166 }  // namespace kernel
167 }  // namespace mindspore
168 #endif  // MINDSPORE_INCLUDE_API_KERNEL_API_H
169