• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_LITE_SRC_RUNTIME_LITE_KERNEL_H_
18 #define MINDSPORE_LITE_SRC_RUNTIME_LITE_KERNEL_H_
19 #include <string>
20 #include <vector>
21 #include <memory>
22 #include <utility>
23 #include <algorithm>
24 #include "src/common/utils.h"
25 #include "src/common/log_util.h"
26 #include "nnacl/op_base.h"
27 #include "src/litert/inner_context.h"
28 #include "src/tensor.h"
29 #include "include/errorcode.h"
30 #include "schema/model_generated.h"
31 #include "src/litert/cxx_api/tensor/tensor_impl.h"
32 #include "include/api/context.h"
33 #include "include/api/kernel.h"
34 #include "src/litert/thread_cost_model.h"
35 #include "extendrt/mindir_loader/abstract_kernel.h"
36 
37 using mindspore::infer::Abstractkernel;
38 
39 namespace mindspore::kernel {
40 class MS_API LiteKernel : public Abstractkernel {
41  public:
42   LiteKernel() = default;
43 
LiteKernel(OpParameter * parameter,std::vector<lite::Tensor * > in_tensors,std::vector<lite::Tensor * > out_tensors,const lite::InnerContext * ctx)44   LiteKernel(OpParameter *parameter, std::vector<lite::Tensor *> in_tensors, std::vector<lite::Tensor *> out_tensors,
45              const lite::InnerContext *ctx)
46       : op_parameter_(parameter),
47         in_tensors_(std::move(in_tensors)),
48         out_tensors_(std::move(out_tensors)),
49         ms_context_(ctx) {
50     if (ctx != nullptr) {
51       thread_num_ = ctx->thread_num_;
52     }
53   }
54 
~LiteKernel()55   virtual ~LiteKernel() {
56     if (op_parameter_ != nullptr) {
57       free(op_parameter_);
58       op_parameter_ = nullptr;
59       FreeWorkspace();
60     }
61   }
62 
63   int Execute() override;
64 
65   int InferShape() override;
66 
Run()67   virtual int Run() { return mindspore::lite::RET_ERROR; }
ReSize()68   int ReSize() override { return mindspore::lite::RET_ERROR; }
69 
70   // called before Run
71   virtual int PreProcess();
72   // called after Run
PostProcess()73   virtual int PostProcess() { return FreeInWorkTensor(); }
74 
CheckInputsValid()75   virtual bool CheckInputsValid() const { return true; }
76 
CheckParamsValid()77   virtual bool CheckParamsValid() const { return true; }
78 
FreeInWorkTensor()79   virtual int FreeInWorkTensor() const {
80     for (auto &in_tensor : this->in_tensors()) {
81       MS_ASSERT(in_tensor != nullptr);
82       in_tensor->DecRefCount();
83     }
84     return lite::RET_OK;
85   }
86 
Prepare()87   int Prepare() override { return mindspore::lite::RET_OK; }
88 
op_parameter()89   OpParameter *op_parameter() const { return op_parameter_; }
90 
set_parameter(OpParameter * param)91   void set_parameter(OpParameter *param) { op_parameter_ = param; }
92 
InferShapeDone()93   virtual bool InferShapeDone() const {
94     auto checker = ms_context_ != nullptr ? static_cast<const lite::InnerContext *>(ms_context_)->get_infer_checker()
95                                           : lite::InferCheckerOutput;
96     return checker != nullptr && checker(in_tensors_, out_tensors_);
97   }
98 
type()99   schema::PrimitiveType type() const override {
100     return (this->op_parameter_ != nullptr) ? schema::PrimitiveType(this->op_parameter_->type_)
101                                             : schema::PrimitiveType_NONE;
102   }
103 
quant_type()104   schema::QuantType quant_type() const override {
105     return (this->op_parameter_ != nullptr) ? schema::QuantType(this->op_parameter_->quant_type_)
106                                             : schema::QuantType_QUANT_NONE;
107   }
108 
inputs()109   const std::vector<mindspore::MSTensor> &inputs() override {
110     if (inputs_.empty()) {
111       std::transform(in_tensors_.begin(), in_tensors_.end(), std::back_inserter(inputs_), [](lite::Tensor *tensor) {
112         return mindspore::MSTensor(std::make_shared<LiteTensorImpl>(tensor));
113       });
114     }
115     return inputs_;
116   }
117 
outputs()118   const std::vector<mindspore::MSTensor> &outputs() override {
119     if (outputs_.empty()) {
120       std::transform(out_tensors_.begin(), out_tensors_.end(), std::back_inserter(outputs_), [](lite::Tensor *tensor) {
121         return mindspore::MSTensor(std::make_shared<LiteTensorImpl>(tensor));
122       });
123     }
124     return outputs_;
125   }
126 
set_in_tensors(const std::vector<lite::Tensor * > & in_tensors)127   void set_in_tensors(const std::vector<lite::Tensor *> &in_tensors) override { this->in_tensors_ = in_tensors; }
128 
set_in_tensor(lite::Tensor * in_tensor,size_t index)129   void set_in_tensor(lite::Tensor *in_tensor, size_t index) override {
130     if (index >= in_tensors_.size()) {
131       MS_LOG(ERROR) << "index: " << index << " larger than in_tensors size: " << in_tensors_.size();
132       return;
133     }
134     this->in_tensors_[index] = in_tensor;
135   }
136 
set_out_tensors(const std::vector<lite::Tensor * > & out_tensors)137   void set_out_tensors(const std::vector<lite::Tensor *> &out_tensors) override { this->out_tensors_ = out_tensors; }
138 
set_out_tensor(lite::Tensor * out_tensor,size_t index)139   void set_out_tensor(lite::Tensor *out_tensor, size_t index) override {
140     if (index >= out_tensors_.size()) {
141       MS_LOG(ERROR) << "index: " << index << " larger than out_tensors size: " << out_tensors_.size();
142       return;
143     }
144     this->out_tensors_[index] = out_tensor;
145   }
146 
in_tensors()147   const std::vector<lite::Tensor *> &in_tensors() const override { return in_tensors_; }
148 
out_tensors()149   const std::vector<lite::Tensor *> &out_tensors() const override { return out_tensors_; }
150 
Train()151   int Train() override {
152     this->train_mode_ = true;
153     return mindspore::lite::RET_OK;
154   }
155 
IsTrain()156   bool IsTrain() const override { return this->train_mode_; }
157 
Eval()158   int Eval() override {
159     this->train_mode_ = false;
160     return mindspore::lite::RET_OK;
161   }
162 
SetupVirtualBatch(int,int)163   virtual int SetupVirtualBatch(int, int) { return mindspore::lite::RET_OK; }
164 
IsEval()165   bool IsEval() const override { return !this->train_mode_; }
166 
SetTrainable(bool trainable)167   void SetTrainable(bool trainable) override { this->trainable_ = trainable; }
168 
IsTrainable()169   bool IsTrainable() const override { return this->trainable_; }
170 
registry_data_type(void)171   TypeId registry_data_type(void) const { return registry_data_type_; }
172 
set_registry_data_type(TypeId data_type)173   void set_registry_data_type(TypeId data_type) { registry_data_type_ = data_type; }
174 
set_workspace_size(size_t value)175   void set_workspace_size(size_t value) { workspace_size_ = value; }
workspace_size()176   virtual size_t workspace_size() { return workspace_size_; }
177   void AllocWorkspace();
178   void FreeWorkspace();
workspace()179   void *workspace() const { return workspace_; }
set_workspace(void * ws)180   void set_workspace(void *ws) {
181     if (ws_allocated_ == false) {
182       workspace_ = ws;
183     }
184   }
185   bool ws_allocated_ = false;
186 
PreparePackedWeight(const lite::Tensor * tensor)187   virtual int PreparePackedWeight(const lite::Tensor *tensor) { return mindspore::lite::RET_OK; }
188 
189  protected:
190   virtual int UpdateThreadNumProcess(int32_t kernel_type, int64_t per_unit_load_num, int64_t per_unit_store_num,
191                                      int64_t unit_num);
192   int UpdateThreadNumPass(int32_t kernel_type, int64_t per_unit_load_num, int64_t per_unit_store_num, int64_t unit_num);
193 
194  protected:
195   OpParameter *op_parameter_ = nullptr;
196   // tensor will free in ~lite_session()
197   std::vector<lite::Tensor *> in_tensors_;
198   std::vector<lite::Tensor *> out_tensors_;
199   bool train_mode_ = false;
200   bool trainable_ = false;  // parameters of this Kernel are trained in Train Session
201   TypeId registry_data_type_ = kTypeUnknown;
202   size_t workspace_size_ = 0;
203   void *workspace_ = nullptr;
204   const lite::InnerContext *ms_context_ = nullptr;
205 
206   int thread_num_ = 1;
207 };
208 }  // namespace mindspore::kernel
209 
210 #endif  // MINDSPORE_LITE_SRC_RUNTIME_LITE_KERNEL_H_
211