1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_INNER_KERNEL_H_ 18 #define MINDSPORE_LITE_SRC_INNER_KERNEL_H_ 19 #include <string> 20 #include <vector> 21 #include <memory> 22 #include <utility> 23 #include <algorithm> 24 #include "src/common/utils.h" 25 #include "src/common/log_util.h" 26 #include "nnacl/op_base.h" 27 #include "src/inner_context.h" 28 #include "src/tensor.h" 29 #include "include/errorcode.h" 30 #include "schema/model_generated.h" 31 #include "src/cxx_api/tensor/tensor_impl.h" 32 #include "include/api/context.h" 33 #include "include/api/kernel.h" 34 35 namespace mindspore::kernel { 36 class InnerKernel : public Kernel { 37 public: 38 InnerKernel() = default; 39 InnerKernel(OpParameter * parameter,std::vector<lite::Tensor * > in_tensors,std::vector<lite::Tensor * > out_tensors,const lite::Context * ctx)40 InnerKernel(OpParameter *parameter, std::vector<lite::Tensor *> in_tensors, std::vector<lite::Tensor *> out_tensors, 41 const lite::Context *ctx) 42 : op_parameter_(parameter), 43 in_tensors_(std::move(in_tensors)), 44 out_tensors_(std::move(out_tensors)), 45 ms_context_(ctx) {} 46 ~InnerKernel()47 virtual ~InnerKernel() { 48 if (op_parameter_ != nullptr) { 49 free(op_parameter_); 50 op_parameter_ = nullptr; 51 FreeWorkspace(); 52 } 53 } 54 55 int Execute() override; 56 57 // called while compiling graph Prepare()58 int Prepare() override { return mindspore::lite::RET_OK; } Run()59 virtual int Run() { return mindspore::lite::RET_ERROR; } ReSize()60 int ReSize() override { return mindspore::lite::RET_ERROR; } 61 62 // called before Run 63 virtual int PreProcess(); 64 // called after Run PostProcess()65 virtual int PostProcess() { return FreeInWorkTensor(); } 66 FreeInWorkTensor()67 virtual int FreeInWorkTensor() const { 68 for (auto &in_tensor : this->in_tensors()) { 69 MS_ASSERT(in_tensor != nullptr); 70 in_tensor->DecRefCount(); 71 } 72 return lite::RET_OK; 73 } 74 Init()75 virtual int Init() { return mindspore::lite::RET_OK; } 76 op_parameter()77 OpParameter *op_parameter() const { return op_parameter_; } 78 InferShapeDone()79 bool InferShapeDone() const { 80 if (std::any_of(in_tensors_.begin(), in_tensors_.end(), 81 [](lite::Tensor *input) { return input->data_type() == kObjectTypeTensorType; })) { 82 return false; 83 } 84 auto shape = out_tensors_.front()->shape(); 85 if (std::find(shape.begin(), shape.end(), -1) != shape.end()) { 86 return false; 87 } 88 return true; 89 } 90 type()91 schema::PrimitiveType type() const override { 92 return (this->op_parameter_ != nullptr) ? schema::PrimitiveType(this->op_parameter_->type_) 93 : schema::PrimitiveType_NONE; 94 } 95 inputs()96 const std::vector<mindspore::MSTensor> &inputs() override { 97 if (inputs_.empty()) { 98 std::transform(in_tensors_.begin(), in_tensors_.end(), std::back_inserter(inputs_), [](lite::Tensor *tensor) { 99 return mindspore::MSTensor(std::make_shared<mindspore::MSTensor::Impl>(tensor)); 100 }); 101 } 102 return inputs_; 103 } 104 outputs()105 const std::vector<mindspore::MSTensor> &outputs() override { 106 if (outputs_.empty()) { 107 std::transform(out_tensors_.begin(), out_tensors_.end(), std::back_inserter(outputs_), [](lite::Tensor *tensor) { 108 return mindspore::MSTensor(std::make_shared<mindspore::MSTensor::Impl>(tensor)); 109 }); 110 } 111 return outputs_; 112 } 113 set_in_tensors(const std::vector<lite::Tensor * > & in_tensors)114 void set_in_tensors(const std::vector<lite::Tensor *> &in_tensors) { this->in_tensors_ = in_tensors; } 115 set_in_tensor(lite::Tensor * in_tensor,size_t index)116 virtual void set_in_tensor(lite::Tensor *in_tensor, size_t index) { 117 if (index >= in_tensors_.size()) { 118 MS_LOG(ERROR) << "index: " << index << " larger than in_tensors size: " << in_tensors_.size(); 119 return; 120 } 121 this->in_tensors_[index] = in_tensor; 122 } 123 set_out_tensors(const std::vector<lite::Tensor * > & out_tensors)124 void set_out_tensors(const std::vector<lite::Tensor *> &out_tensors) { this->out_tensors_ = out_tensors; } 125 set_out_tensor(lite::Tensor * out_tensor,size_t index)126 virtual void set_out_tensor(lite::Tensor *out_tensor, size_t index) { 127 if (index >= out_tensors_.size()) { 128 MS_LOG(ERROR) << "index: " << index << " larger than out_tensors size: " << out_tensors_.size(); 129 return; 130 } 131 this->out_tensors_[index] = out_tensor; 132 } 133 in_tensors()134 const std::vector<lite::Tensor *> &in_tensors() const { return in_tensors_; } 135 out_tensors()136 const std::vector<lite::Tensor *> &out_tensors() const { return out_tensors_; } 137 Train()138 virtual int Train() { 139 this->train_mode_ = true; 140 return mindspore::lite::RET_OK; 141 } 142 IsTrain()143 virtual bool IsTrain() const { return this->train_mode_; } 144 Eval()145 virtual int Eval() { 146 this->train_mode_ = false; 147 return mindspore::lite::RET_OK; 148 } 149 IsEval()150 virtual bool IsEval() const { return !this->train_mode_; } 151 152 virtual void SetTrainable(bool trainable = true) { this->trainable_ = trainable; } 153 IsTrainable()154 virtual bool IsTrainable() const { return this->trainable_; } 155 registry_data_type(void)156 TypeId registry_data_type(void) { return registry_data_type_; } 157 set_registry_data_type(TypeId data_type)158 void set_registry_data_type(TypeId data_type) { registry_data_type_ = data_type; } 159 set_workspace_size(size_t value)160 void set_workspace_size(size_t value) { workspace_size_ = value; } workspace_size()161 virtual size_t workspace_size() { return workspace_size_; } 162 void AllocWorkspace(); 163 void FreeWorkspace(); workspace()164 void *workspace() { return workspace_; } set_workspace(void * ws)165 void set_workspace(void *ws) { 166 if (ws_allocated_ == false) { 167 workspace_ = ws; 168 } 169 } context()170 const lite::Context *context() const { return this->ms_context_; } 171 bool ws_allocated_ = false; 172 173 protected: 174 OpParameter *op_parameter_ = nullptr; 175 // tensor will free in ~lite_session() 176 std::vector<lite::Tensor *> in_tensors_; 177 std::vector<lite::Tensor *> out_tensors_; 178 bool train_mode_ = false; 179 bool trainable_ = false; // parameters of this Kernel are trained in Train Session 180 TypeId registry_data_type_ = kTypeUnknown; 181 size_t workspace_size_ = 0; 182 void *workspace_ = nullptr; 183 const lite::Context *ms_context_ = nullptr; 184 185 int thread_num_ = 1; 186 }; 187 } // namespace mindspore::kernel 188 189 #endif // MINDSPORE_LITE_SRC_INNER_KERNEL_H_ 190