/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_LITERT_CACHE_SESSION_H_
#define MINDSPORE_LITE_SRC_LITERT_CACHE_SESSION_H_

#include "src/litert/lite_session.h"
#include "src/litert/inner_context.h"
#include "src/litert/lite_model.h"
#include "src/litert/delegate/nnrt/extension_options_parser.h"
#include "neural_network_runtime/neural_network_runtime_type.h"
#include "neural_network_runtime/neural_network_runtime.h"
#include "neural_network_runtime_inner.h"

namespace mindspore {
namespace lite {
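// CacheSession is a LiteSession variant that delegates graph compilation and
// execution to the Neural Network Runtime (NNRT): it parses only the model's
// input/output tensor metadata and hands the full graph to an OH_NNExecutor
// (summary inferred from the declarations in this header).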
class CacheSession : public LiteSession {
 public:
  CacheSession() = default;
  ~CacheSession() override;
  int Init(const std::shared_ptr<InnerContext> &context) override;
  int CompileGraph(Model *model) override;
  int LoadModelAndCompileByPath(const std::string &model_path, mindspore::ModelType model_type) override;
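  // Presumably true when the NNRT device identified by device_id is a Kirin NPU
  // running in online-inference mode (inferred from the method name).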
  static bool IsKirinNPUWithOnlineInference(size_t device_id);
  const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size,
                              bool use_mmap) override;
  Model *ImportInOutFromBuffer(const char *model_buf, size_t size, bool take_buf,
                               mindspore::ModelType model_type = mindspore::ModelType::kMindIR_Lite,
                               const std::string &path = "");

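  // Resolves the schema::Tensor pointers for the graph's input and output indices
  // from the flatbuffer metagraph and stores them in graph_.all_tensors_. Only the
  // input/output slots are populated; the remaining entries stay null.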
  template <typename T = schema::MetaGraph>
  bool ConvertInputOutputTensors(const T &meta_graph, LiteGraph &graph_) {
    if (meta_graph.allTensors() == nullptr) {
      MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
      return false;
    }

    graph_.all_tensors_.resize(meta_graph.allTensors()->size());
    MS_LOG(INFO) << "convert input/output tensors";
    for (auto i : graph_.input_indices_) {
      auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i);
      if (tensor == nullptr) {
        MS_LOG(ERROR) << i << " the input tensor in metagraph is nullptr";
        return false;
      }
      MS_CHECK_TRUE_RET(tensor->format() >= schema::Format_MIN && tensor->format() <= schema::Format_MAX, false);
      graph_.all_tensors_[i] = const_cast<mindspore::schema::Tensor *>(tensor);
    }

    for (auto i : graph_.output_indices_) {
      auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i);
      if (tensor == nullptr) {
        MS_LOG(ERROR) << i << " the output tensor in metagraph is nullptr";
        return false;
      }
      MS_CHECK_TRUE_RET(tensor->format() >= schema::Format_MIN && tensor->format() <= schema::Format_MAX, false);
      graph_.all_tensors_[i] = const_cast<mindspore::schema::Tensor *>(tensor);
    }
    return true;
  }

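  // Copies the model name/version and the input/output tensor indices from the
  // flatbuffer metagraph into graph_, then resolves the corresponding tensors via
  // ConvertInputOutputTensors. Returns RET_OK on success.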
  template <typename T = schema::MetaGraph, typename U = schema::CNode>
  int GenerateModelInputOutput(const T &meta_graph, LiteGraph &graph_) {
    if (meta_graph.name() != nullptr) {
      graph_.name_ = meta_graph.name()->c_str();
    }
    if (meta_graph.version() != nullptr) {
      graph_.version_ = meta_graph.version()->c_str();
    }

    if (meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr ||
        meta_graph.allTensors() == nullptr) {
      MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
      return RET_ERROR;
    }

    // collect the graph's input/output tensor indices
    auto in_count = meta_graph.inputIndex()->size();
    for (uint32_t i = 0; i < in_count; ++i) {
      graph_.input_indices_.push_back(meta_graph.inputIndex()->Get(i));
    }
    auto out_count = meta_graph.outputIndex()->size();
    for (uint32_t i = 0; i < out_count; ++i) {
      graph_.output_indices_.push_back(meta_graph.outputIndex()->Get(i));
    }

    if (!ConvertInputOutputTensors<T>(meta_graph, graph_)) {
      MS_LOG(ERROR) << "convert tensor failed";
      return RET_ERROR;
    }
    return RET_OK;
  }

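  // Presumably parses only the input/output tensor metadata out of a serialized
  // model buffer into the given LiteModel (inferred from the declaration).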
  int ParseInputOutputFromModelBuffer(const char *model_buf, LiteModel *model);
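  // GL texture binding is not supported for cache sessions; this always fails.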
  int BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGLTexture,
                            std::map<std::string, unsigned int> *outputGLTexture) override {
    return RET_ERROR;
  }

 protected:
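  // NNRT compilation path (roles inferred from the method names):
  // ScheduleToNNRTKernel routes the graph to an NNRT kernel, CreateFullModelKernel
  // builds a single kernel covering the whole model, and InitNNCompilation
  // configures the OH_NNCompilation instance before execution.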
  int ScheduleToNNRTKernel();
  Status CreateFullModelKernel();
  Status InitNNCompilation(OH_NNCompilation *nn_compilation) const;
  int ConvertInOutTensors(const lite::Model *model);
  int InitExecutor() override;
  std::vector<mindspore::MSTensor> ms_inputs_;
  std::vector<mindspore::MSTensor> ms_outputs_;

 private:
  NNRtDeviceInfo nnrt_device_info_;
  OH_NNExecutor *nn_executor_{nullptr};
  nnrt::ExtensionOptions extension_options_;
};
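
// Illustrative usage sketch, not part of the API: it assumes callers construct the
// session directly and that InnerContext is default-constructible; the production
// entry point may differ.
//
//   auto ctx = std::make_shared<mindspore::lite::InnerContext>();
//   mindspore::lite::CacheSession session;
//   if (session.Init(ctx) == mindspore::lite::RET_OK) {
//     session.LoadModelAndCompileByPath("model.ms", mindspore::ModelType::kMindIR_Lite);
//   }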
}  // namespace lite
}  // namespace mindspore

#endif  // MINDSPORE_LITE_SRC_LITERT_CACHE_SESSION_H_