/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_LITE_KERNEL_H_
#define MINDSPORE_LITE_SRC_RUNTIME_LITE_KERNEL_H_
#include <string>
#include <vector>
#include <memory>
#include <utility>
#include <algorithm>
#include "src/common/utils.h"
#include "src/common/log_util.h"
#include "nnacl/op_base.h"
#include "src/litert/inner_context.h"
#include "src/tensor.h"
#include "include/errorcode.h"
#include "schema/model_generated.h"
#include "src/litert/cxx_api/tensor/tensor_impl.h"
#include "include/api/context.h"
#include "include/api/kernel.h"
#include "src/litert/thread_cost_model.h"
#include "extendrt/mindir_loader/abstract_kernel.h"

using mindspore::infer::Abstractkernel;

namespace mindspore::kernel {
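// LiteKernel is the base class for single-operator kernels in the lite runtime. It owns its
// OpParameter (freed in the destructor), while the input and output tensors it references are
// owned by the session (see the note on in_tensors_/out_tensors_ below). Concrete kernels are
// expected to override Prepare(), ReSize() and Run(); the default Run()/ReSize() simply return
// RET_ERROR. Execute() is the runtime's entry point: PreProcess() is called before Run() and
// PostProcess() after it (see the comments on those members); Execute() itself is defined
// outside this header.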
class MS_API LiteKernel : public Abstractkernel {
 public:
  LiteKernel() = default;

  LiteKernel(OpParameter *parameter, std::vector<lite::Tensor *> in_tensors, std::vector<lite::Tensor *> out_tensors,
             const lite::InnerContext *ctx)
      : op_parameter_(parameter),
        in_tensors_(std::move(in_tensors)),
        out_tensors_(std::move(out_tensors)),
        ms_context_(ctx) {
    if (ctx != nullptr) {
      thread_num_ = ctx->thread_num_;
    }
  }

  virtual ~LiteKernel() {
    if (op_parameter_ != nullptr) {
      free(op_parameter_);
      op_parameter_ = nullptr;
      FreeWorkspace();
    }
  }

  int Execute() override;

  int InferShape() override;

  virtual int Run() { return mindspore::lite::RET_ERROR; }

  int ReSize() override { return mindspore::lite::RET_ERROR; }

  // called before Run
  virtual int PreProcess();
  // called after Run
  virtual int PostProcess() { return FreeInWorkTensor(); }

  virtual bool CheckInputsValid() const { return true; }

  virtual bool CheckParamsValid() const { return true; }

  virtual int FreeInWorkTensor() const {
    for (auto &in_tensor : this->in_tensors()) {
      MS_ASSERT(in_tensor != nullptr);
      in_tensor->DecRefCount();
    }
    return lite::RET_OK;
  }

  int Prepare() override { return mindspore::lite::RET_OK; }

  OpParameter *op_parameter() const { return op_parameter_; }

  void set_parameter(OpParameter *param) { op_parameter_ = param; }

  virtual bool InferShapeDone() const {
    auto checker = ms_context_ != nullptr ? static_cast<const lite::InnerContext *>(ms_context_)->get_infer_checker()
                                          : lite::InferCheckerOutput;
    return checker != nullptr && checker(in_tensors_, out_tensors_);
  }

  schema::PrimitiveType type() const override {
    return (this->op_parameter_ != nullptr) ? schema::PrimitiveType(this->op_parameter_->type_)
                                            : schema::PrimitiveType_NONE;
  }

  schema::QuantType quant_type() const override {
    return (this->op_parameter_ != nullptr) ? schema::QuantType(this->op_parameter_->quant_type_)
                                            : schema::QuantType_QUANT_NONE;
  }

  const std::vector<mindspore::MSTensor> &inputs() override {
    if (inputs_.empty()) {
      std::transform(in_tensors_.begin(), in_tensors_.end(), std::back_inserter(inputs_), [](lite::Tensor *tensor) {
        return mindspore::MSTensor(std::make_shared<LiteTensorImpl>(tensor));
      });
    }
    return inputs_;
  }

  const std::vector<mindspore::MSTensor> &outputs() override {
    if (outputs_.empty()) {
      std::transform(out_tensors_.begin(), out_tensors_.end(), std::back_inserter(outputs_), [](lite::Tensor *tensor) {
        return mindspore::MSTensor(std::make_shared<LiteTensorImpl>(tensor));
      });
    }
    return outputs_;
  }

  void set_in_tensors(const std::vector<lite::Tensor *> &in_tensors) override { this->in_tensors_ = in_tensors; }

  void set_in_tensor(lite::Tensor *in_tensor, size_t index) override {
    if (index >= in_tensors_.size()) {
      MS_LOG(ERROR) << "index: " << index << " larger than in_tensors size: " << in_tensors_.size();
      return;
    }
    this->in_tensors_[index] = in_tensor;
  }

  void set_out_tensors(const std::vector<lite::Tensor *> &out_tensors) override { this->out_tensors_ = out_tensors; }

  void set_out_tensor(lite::Tensor *out_tensor, size_t index) override {
    if (index >= out_tensors_.size()) {
      MS_LOG(ERROR) << "index: " << index << " larger than out_tensors size: " << out_tensors_.size();
      return;
    }
    this->out_tensors_[index] = out_tensor;
  }

  const std::vector<lite::Tensor *> &in_tensors() const override { return in_tensors_; }

  const std::vector<lite::Tensor *> &out_tensors() const override { return out_tensors_; }

  int Train() override {
    this->train_mode_ = true;
    return mindspore::lite::RET_OK;
  }

  bool IsTrain() const override { return this->train_mode_; }

  int Eval() override {
    this->train_mode_ = false;
    return mindspore::lite::RET_OK;
  }

  virtual int SetupVirtualBatch(int, int) { return mindspore::lite::RET_OK; }

  bool IsEval() const override { return !this->train_mode_; }

  void SetTrainable(bool trainable) override { this->trainable_ = trainable; }

  bool IsTrainable() const override { return this->trainable_; }

  TypeId registry_data_type(void) const { return registry_data_type_; }

  void set_registry_data_type(TypeId data_type) { registry_data_type_ = data_type; }

  void set_workspace_size(size_t value) { workspace_size_ = value; }
  virtual size_t workspace_size() { return workspace_size_; }
  void AllocWorkspace();
  void FreeWorkspace();
  void *workspace() const { return workspace_; }
  void set_workspace(void *ws) {
    if (ws_allocated_ == false) {
      workspace_ = ws;
    }
  }
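  // Set when the kernel manages its own workspace buffer (AllocWorkspace()/FreeWorkspace() are
  // declared above but defined outside this header); set_workspace() only accepts an external
  // buffer while this flag is still false.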
  bool ws_allocated_ = false;

  virtual int PreparePackedWeight(const lite::Tensor *tensor) { return mindspore::lite::RET_OK; }

 protected:
  virtual int UpdateThreadNumProcess(int32_t kernel_type, int64_t per_unit_load_num, int64_t per_unit_store_num,
                                     int64_t unit_num);
  int UpdateThreadNumPass(int32_t kernel_type, int64_t per_unit_load_num, int64_t per_unit_store_num, int64_t unit_num);

 protected:
  OpParameter *op_parameter_ = nullptr;
  // these tensors are freed in ~lite_session()
  std::vector<lite::Tensor *> in_tensors_;
  std::vector<lite::Tensor *> out_tensors_;
  bool train_mode_ = false;
  bool trainable_ = false;  // parameters of this kernel are trained in the Train Session
  TypeId registry_data_type_ = kTypeUnknown;
  size_t workspace_size_ = 0;
  void *workspace_ = nullptr;
  const lite::InnerContext *ms_context_ = nullptr;

  int thread_num_ = 1;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_LITE_KERNEL_H_