• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_INNER_KERNEL_H_
18 #define MINDSPORE_LITE_SRC_INNER_KERNEL_H_
19 #include <string>
20 #include <vector>
21 #include <memory>
22 #include <utility>
23 #include <algorithm>
24 #include "src/common/utils.h"
25 #include "src/common/log_util.h"
26 #include "nnacl/op_base.h"
27 #include "src/inner_context.h"
28 #include "src/tensor.h"
29 #include "include/errorcode.h"
30 #include "schema/model_generated.h"
31 #include "src/cxx_api/tensor/tensor_impl.h"
32 #include "include/api/context.h"
33 #include "include/api/kernel.h"
34 
35 namespace mindspore::kernel {
36 class InnerKernel : public Kernel {
37  public:
38   InnerKernel() = default;
39 
InnerKernel(OpParameter * parameter,std::vector<lite::Tensor * > in_tensors,std::vector<lite::Tensor * > out_tensors,const lite::Context * ctx)40   InnerKernel(OpParameter *parameter, std::vector<lite::Tensor *> in_tensors, std::vector<lite::Tensor *> out_tensors,
41               const lite::Context *ctx)
42       : op_parameter_(parameter),
43         in_tensors_(std::move(in_tensors)),
44         out_tensors_(std::move(out_tensors)),
45         ms_context_(ctx) {}
46 
~InnerKernel()47   virtual ~InnerKernel() {
48     if (op_parameter_ != nullptr) {
49       free(op_parameter_);
50       op_parameter_ = nullptr;
51       FreeWorkspace();
52     }
53   }
54 
55   int Execute() override;
56 
57   // called while compiling graph
Prepare()58   int Prepare() override { return mindspore::lite::RET_OK; }
Run()59   virtual int Run() { return mindspore::lite::RET_ERROR; }
ReSize()60   int ReSize() override { return mindspore::lite::RET_ERROR; }
61 
62   // called before Run
63   virtual int PreProcess();
64   // called after Run
PostProcess()65   virtual int PostProcess() { return FreeInWorkTensor(); }
66 
FreeInWorkTensor()67   virtual int FreeInWorkTensor() const {
68     for (auto &in_tensor : this->in_tensors()) {
69       MS_ASSERT(in_tensor != nullptr);
70       in_tensor->DecRefCount();
71     }
72     return lite::RET_OK;
73   }
74 
Init()75   virtual int Init() { return mindspore::lite::RET_OK; }
76 
op_parameter()77   OpParameter *op_parameter() const { return op_parameter_; }
78 
InferShapeDone()79   bool InferShapeDone() const {
80     if (std::any_of(in_tensors_.begin(), in_tensors_.end(),
81                     [](lite::Tensor *input) { return input->data_type() == kObjectTypeTensorType; })) {
82       return false;
83     }
84     auto shape = out_tensors_.front()->shape();
85     if (std::find(shape.begin(), shape.end(), -1) != shape.end()) {
86       return false;
87     }
88     return true;
89   }
90 
type()91   schema::PrimitiveType type() const override {
92     return (this->op_parameter_ != nullptr) ? schema::PrimitiveType(this->op_parameter_->type_)
93                                             : schema::PrimitiveType_NONE;
94   }
95 
inputs()96   const std::vector<mindspore::MSTensor> &inputs() override {
97     if (inputs_.empty()) {
98       std::transform(in_tensors_.begin(), in_tensors_.end(), std::back_inserter(inputs_), [](lite::Tensor *tensor) {
99         return mindspore::MSTensor(std::make_shared<mindspore::MSTensor::Impl>(tensor));
100       });
101     }
102     return inputs_;
103   }
104 
outputs()105   const std::vector<mindspore::MSTensor> &outputs() override {
106     if (outputs_.empty()) {
107       std::transform(out_tensors_.begin(), out_tensors_.end(), std::back_inserter(outputs_), [](lite::Tensor *tensor) {
108         return mindspore::MSTensor(std::make_shared<mindspore::MSTensor::Impl>(tensor));
109       });
110     }
111     return outputs_;
112   }
113 
set_in_tensors(const std::vector<lite::Tensor * > & in_tensors)114   void set_in_tensors(const std::vector<lite::Tensor *> &in_tensors) { this->in_tensors_ = in_tensors; }
115 
set_in_tensor(lite::Tensor * in_tensor,size_t index)116   virtual void set_in_tensor(lite::Tensor *in_tensor, size_t index) {
117     if (index >= in_tensors_.size()) {
118       MS_LOG(ERROR) << "index: " << index << " larger than in_tensors size: " << in_tensors_.size();
119       return;
120     }
121     this->in_tensors_[index] = in_tensor;
122   }
123 
set_out_tensors(const std::vector<lite::Tensor * > & out_tensors)124   void set_out_tensors(const std::vector<lite::Tensor *> &out_tensors) { this->out_tensors_ = out_tensors; }
125 
set_out_tensor(lite::Tensor * out_tensor,size_t index)126   virtual void set_out_tensor(lite::Tensor *out_tensor, size_t index) {
127     if (index >= out_tensors_.size()) {
128       MS_LOG(ERROR) << "index: " << index << " larger than out_tensors size: " << out_tensors_.size();
129       return;
130     }
131     this->out_tensors_[index] = out_tensor;
132   }
133 
in_tensors()134   const std::vector<lite::Tensor *> &in_tensors() const { return in_tensors_; }
135 
out_tensors()136   const std::vector<lite::Tensor *> &out_tensors() const { return out_tensors_; }
137 
Train()138   virtual int Train() {
139     this->train_mode_ = true;
140     return mindspore::lite::RET_OK;
141   }
142 
IsTrain()143   virtual bool IsTrain() const { return this->train_mode_; }
144 
Eval()145   virtual int Eval() {
146     this->train_mode_ = false;
147     return mindspore::lite::RET_OK;
148   }
149 
IsEval()150   virtual bool IsEval() const { return !this->train_mode_; }
151 
152   virtual void SetTrainable(bool trainable = true) { this->trainable_ = trainable; }
153 
IsTrainable()154   virtual bool IsTrainable() const { return this->trainable_; }
155 
registry_data_type(void)156   TypeId registry_data_type(void) { return registry_data_type_; }
157 
set_registry_data_type(TypeId data_type)158   void set_registry_data_type(TypeId data_type) { registry_data_type_ = data_type; }
159 
set_workspace_size(size_t value)160   void set_workspace_size(size_t value) { workspace_size_ = value; }
workspace_size()161   virtual size_t workspace_size() { return workspace_size_; }
162   void AllocWorkspace();
163   void FreeWorkspace();
workspace()164   void *workspace() { return workspace_; }
set_workspace(void * ws)165   void set_workspace(void *ws) {
166     if (ws_allocated_ == false) {
167       workspace_ = ws;
168     }
169   }
context()170   const lite::Context *context() const { return this->ms_context_; }
171   bool ws_allocated_ = false;
172 
173  protected:
174   OpParameter *op_parameter_ = nullptr;
175   // tensor will free in ~lite_session()
176   std::vector<lite::Tensor *> in_tensors_;
177   std::vector<lite::Tensor *> out_tensors_;
178   bool train_mode_ = false;
179   bool trainable_ = false;  // parameters of this Kernel are trained in Train Session
180   TypeId registry_data_type_ = kTypeUnknown;
181   size_t workspace_size_ = 0;
182   void *workspace_ = nullptr;
183   const lite::Context *ms_context_ = nullptr;
184 
185   int thread_num_ = 1;
186 };
187 }  // namespace mindspore::kernel
188 
189 #endif  // MINDSPORE_LITE_SRC_INNER_KERNEL_H_
190