1 /** 2 * Copyright 2021-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MINDIR_MODEL_INNER_KERNEL_H_ 18 #define MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MINDIR_MODEL_INNER_KERNEL_H_ 19 20 #include <utility> 21 #include <vector> 22 #include <memory> 23 24 #include "src/tensor.h" 25 #include "include/errorcode.h" 26 #include "include/api/kernel.h" 27 #include "src/litert/inner_context.h" 28 // #include "include/api/context.h" 29 #include "kernel/kernel.h" 30 #include "extendrt/mindir_loader/abstract_kernel.h" 31 #include "src/extendrt/utils/tensor_utils.h" 32 33 using mindspore::infer::Abstractkernel; 34 35 namespace mindspore::kernel { 36 class InnerKernel : public Abstractkernel { 37 public: 38 InnerKernel() = default; 39 InnerKernel(std::shared_ptr<mindspore::kernel::KernelMod> kernel_mod,mindspore::kernel::BaseOperatorPtr base_operator,std::vector<lite::Tensor * > in_tensors,std::vector<lite::Tensor * > out_tensors,const lite::InnerContext * ctx)40 InnerKernel(std::shared_ptr<mindspore::kernel::KernelMod> kernel_mod, 41 mindspore::kernel::BaseOperatorPtr base_operator, std::vector<lite::Tensor *> in_tensors, 42 std::vector<lite::Tensor *> out_tensors, const lite::InnerContext *ctx) 43 : kernel_mod_(kernel_mod), 44 base_operator_(base_operator), 45 in_tensors_(std::move(in_tensors)), 46 out_tensors_(std::move(out_tensors)), 47 ms_context_(ctx) {} 48 ~InnerKernel()49 virtual ~InnerKernel() {} 50 51 int Prepare() override; 52 53 int Execute() override; 54 55 int ReSize() override; 56 InferShape()57 int InferShape() override { return lite::RET_ERROR; } 58 Train()59 int Train() override { return mindspore::lite::RET_OK; } 60 IsTrain()61 bool IsTrain() const override { return true; } 62 Eval()63 int Eval() override { return mindspore::lite::RET_OK; } 64 IsEval()65 bool IsEval() const override { return true; } 66 67 void SetTrainable(bool trainable = true) override {} 68 IsTrainable()69 bool IsTrainable() const override { return true; } 70 set_in_tensors(const std::vector<lite::Tensor * > & in_tensors)71 void set_in_tensors(const std::vector<lite::Tensor *> &in_tensors) override { this->in_tensors_ = in_tensors; } 72 set_in_tensor(lite::Tensor * in_tensor,size_t index)73 void set_in_tensor(lite::Tensor *in_tensor, size_t index) override { 74 if (index >= in_tensors_.size()) { 75 MS_LOG(ERROR) << "index: " << index << " larger than in_tensors size: " << in_tensors_.size(); 76 return; 77 } 78 this->in_tensors_[index] = in_tensor; 79 } 80 set_out_tensors(const std::vector<lite::Tensor * > & out_tensors)81 void set_out_tensors(const std::vector<lite::Tensor *> &out_tensors) override { this->out_tensors_ = out_tensors; } 82 set_out_tensor(lite::Tensor * out_tensor,size_t index)83 void set_out_tensor(lite::Tensor *out_tensor, size_t index) override { 84 if (index >= out_tensors_.size()) { 85 MS_LOG(ERROR) << "index: " << index << " larger than out_tensors size: " << out_tensors_.size(); 86 return; 87 } 88 this->out_tensors_[index] = out_tensor; 89 } 90 in_tensors()91 const std::vector<lite::Tensor *> &in_tensors() const override { return in_tensors_; } 92 out_tensors()93 const std::vector<lite::Tensor *> &out_tensors() const override { return out_tensors_; } 94 95 private: 96 std::shared_ptr<mindspore::kernel::KernelMod> kernel_mod_ = nullptr; 97 BaseOperatorPtr base_operator_ = nullptr; 98 std::vector<lite::Tensor *> in_tensors_; 99 std::vector<lite::Tensor *> out_tensors_; 100 const mindspore::lite::InnerContext *ms_context_ = nullptr; 101 }; 102 } // namespace mindspore::kernel 103 104 #endif // MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MINDIR_MODEL_INNER_KERNEL_H_ 105