/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_LITE_SESSION_H_
#define MINDSPORE_LITE_SRC_LITE_SESSION_H_

#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include <map>
#include <atomic>
#include "src/lite_kernel.h"
#include "include/ms_tensor.h"
#include "include/lite_session.h"
#include "include/model.h"
#include "src/inner_context.h"
#include "schema/model_generated.h"
#include "src/executor.h"
#include "src/tensor.h"
#ifndef CONTROLFLOW_TENSORLIST_CLIP
#include "src/tensorlist.h"
#endif
#ifndef DELEGATE_CLIP
#include "include/api/delegate.h"
#endif
#if GPU_OPENCL
#include "src/runtime/gpu/opencl/opencl_runtime.h"
#endif
#include "src/scheduler_cb.h"

namespace mindspore {
namespace lite {
// Concrete implementation of the public session::LiteSession interface.
// Owns the compiled kernel list, the graph's tensors, and the lookup maps
// from tensor/node names to MSTensor handles used by the public getters.
class LiteSession : public session::LiteSession {
 public:
  LiteSession();

  ~LiteSession() override;

  // Factory: loads the model at model_path and builds a ready-to-run session.
  // NOTE(review): ownership of the returned session presumably passes to the
  // caller — confirm against the public API documentation.
  static session::LiteSession *CreateSession(const std::string &model_path, const lite::Context *context);

  // Build the session from an in-memory model buffer / a model file path.
  // Return value is a lite error code (declarations only here; see .cc).
  int LoadModelAndCompileByBuf(const char *model_buf, size_t buf_size);

  int LoadModelAndCompileByPath(const std::string &model_path);

  // Takes over `context` for the lifetime of the session (stored in context_).
  virtual int Init(InnerContext *context);

  void BindThread(bool if_bind) override;

  int CompileGraph(Model *model) override;

  std::vector<mindspore::tensor::MSTensor *> GetInputs() const override;

  mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &name) const override;

  // Optional callbacks fire before/after each kernel execution.
  int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;

  std::vector<mindspore::tensor::MSTensor *> GetOutputsByNodeName(const std::string &node_name) const override;

  std::vector<std::string> GetOutputTensorNames() const override;

  mindspore::tensor::MSTensor *GetOutputByTensorName(const std::string &tensor_name) const override;

  std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetOutputs() const override;

  // Re-shapes the graph inputs; dims[i] is the new shape for inputs[i].
  int Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs,
             const std::vector<std::vector<int>> &dims) override;

  // Stores a borrowed pointer to a per-node execution-type plan; the caller
  // keeps ownership of `config` (see execution_plan_ below).
  void InitExecutionConfig(std::map<std::string, TypeId> *config) { execution_plan_ = config; }

  // Borrowed pointer: the session does not take ownership of `model` here.
  void set_model(Model *model) { this->model_ = model; }

  const std::vector<kernel::LiteKernel *> &get_kernels() const { return this->kernels_; }

  // Non-owning view of the delegate (may be null when no delegate is set).
  const Delegate *get_delegate() const { return this->delegate_.get(); }

 protected:
  // --- model -> runtime tensor conversion helpers (flatbuffer schema::Tensor
  //     into lite::Tensor); populate tensors_ ---
  static void ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lite::Tensor *dst_tensor);

  int ConvertTensorsData(const lite::Model *model, size_t tensor_index, const schema::Tensor *src_tensor,
                         lite::Tensor *dst_tensor);

  lite::Tensor *ConvertTensor(const schema::Tensor &src_tensor);

  int ConvertTensors(const lite::Model *model);

  // --- graph input/output bookkeeping: fill inputs_/outputs_, input_vec_,
  //     and the name->tensor lookup maps used by the public getters ---
  void InitGraphInOutTensorsMap(const lite::Model *model);

  int IsolateOutputTensor();

  void InitGraphInputTensors(const lite::Model *model);

  void InitGraphInputMSTensors();

  void InitGraphOutputTensors(const lite::Model *model);

  void InitGraphInputMap(const lite::Model *model);

  void InitGraphOutputNodeMap(const lite::Model *model);

  void InitGraphOutputTensorMap(const lite::Model *model);

  void AdjustModelOutputTensorInitRefCount(const lite::Model *model);

  // Validates/applies new input shapes for Resize().
  int ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims);

  int SetAllocatorForDelegateKernels(const kernel::LiteKernel *kernel);

  int PrepareKernels(Model *model);

  static int ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels);

  // Releases packed weight buffers held by kernels (memory reclamation).
  static void FreePackOpWeight(const std::vector<kernel::LiteKernel *> &kernels);

 private:
  void ResetInputsShape(const std::vector<std::vector<int>> &dims);

  int InitGPURuntime();

  bool IsIsolatedSubGraph(kernel::LiteKernel *kernel);

 protected:
  InnerContext *context_ = nullptr;      // borrowed/owned? set via Init() — see .cc for lifetime
  mindspore::Context *ms_context_ = nullptr;
  std::vector<kernel::LiteKernel *> kernels_;  // compiled execution kernels, in schedule order
  std::vector<Tensor *> tensors_;              // all runtime tensors converted from the model
  // graph input tensors
  std::vector<Tensor *> inputs_;
  // graph output tensors
  std::vector<Tensor *> outputs_;
  // graph input MSTensors
  std::vector<mindspore::tensor::MSTensor *> input_vec_;
  // graph input tensor name -- input tensors
  std::unordered_map<std::string, mindspore::tensor::MSTensor *> input_map_;
  // graph output node name -- output tensors
  std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> output_node_map_;

  std::vector<std::string> output_tensor_names_;
  // graph output tensor name -- output tensor
  std::unordered_map<std::string, mindspore::tensor::MSTensor *> output_tensor_map_;
  std::unordered_map<Tensor *, Tensor *> graph_output_map_; /* <calculate-tensor, graph-output-tensor> */
  Executor *executor_ = nullptr;
  Model *model_ = nullptr;   // borrowed via set_model() or set during compile — confirm ownership in .cc
  std::atomic<bool> is_running_ = {false};  // guards against concurrent RunGraph/Resize on one session
  bool is_train_session_ = false;
  friend class TransferSession;
#if GPU_OPENCL
  opencl::OpenCLRuntimeInnerWrapper *opencl_runtime_wrapper_{nullptr};
#endif
  std::unique_ptr<SchedulerCb> sched_cb_;
  std::shared_ptr<Delegate> delegate_ = nullptr;
  int delegate_device_type_ = -1;  // -1: not specified; 0: CPU; 1: GPU; 2: NPU
  // Borrowed pointer set by InitExecutionConfig(); maps node name -> forced data type.
  std::map<std::string, TypeId> *execution_plan_ = nullptr;
};
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_SRC_LITE_SESSION_H_