• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include "src/litert/cxx_api/model/model_group_impl.h"
#include <algorithm>
#include <atomic>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/types.h"
#include "include/api/context.h"
#include "src/litert/cxx_api/converters.h"
#include "src/common/log_adapter.h"
#include "src/litert/lite_session.h"
#include "src/litert/model_manager.h"
#include "src/common/config_file.h"
31 
32 namespace mindspore {
33 using mindspore::lite::RET_OK;
ModelGroupImpl(ModelGroupFlag flags)34 ModelGroupImpl::ModelGroupImpl(ModelGroupFlag flags) : flags_(flags) {
35   static uint32_t g_model_group_id = 0;
36   model_group_id_ = ++g_model_group_id;
37 }
38 
AddModel(const std::vector<std::string> & model_path_list)39 Status ModelGroupImpl::AddModel(const std::vector<std::string> &model_path_list) {
40   if (flags_ != ModelGroupFlag::kShareWorkspace) {
41     MS_LOG(ERROR) << "Only support share workspace for ModelGroup::AddModel(const std::vector<std::string> &)";
42     return kLiteError;
43   }
44   if (model_path_list.empty()) {
45     MS_LOG(ERROR) << "Param model_path_list is empty.";
46     return kLiteParamInvalid;
47   }
48   for (auto &model_path : model_path_list) {
49     if (model_path.empty()) {
50       continue;
51     }
52     (void)model_path_list_.emplace_back(model_path);
53   }
54 
55   return kSuccess;
56 }
57 
AddModel(const std::vector<std::pair<const void *,size_t>> & model_buff_list)58 Status ModelGroupImpl::AddModel(const std::vector<std::pair<const void *, size_t>> &model_buff_list) {
59   if (flags_ != ModelGroupFlag::kShareWorkspace) {
60     MS_LOG(ERROR)
61       << "Only support share workspace for ModelGroup::AddModel(const std::vector<std::pair<const void *, size_t>> &)";
62     return kLiteError;
63   }
64   if (model_buff_list.empty()) {
65     MS_LOG(ERROR) << "Param model_buff_list is empty.";
66     return kLiteParamInvalid;
67   }
68   for (auto &model_buff : model_buff_list) {
69     if (model_buff.first == nullptr || model_buff.second == 0) {
70       continue;
71     }
72     (void)model_buff_list_.emplace_back(model_buff);
73   }
74 
75   return kSuccess;
76 }
77 
CreateLiteSession(const std::shared_ptr<Context> & ms_context)78 lite::LiteSession *ModelGroupImpl::CreateLiteSession(const std::shared_ptr<Context> &ms_context) {
79   auto session = new (std::nothrow) lite::LiteSession();
80   if (session == nullptr) {
81     return nullptr;
82   }
83 
84   std::string sharing_workspace_section = "inner_common";
85   std::string calc_workspace_key = "inner_calc_workspace_size";
86   std::string calc_workspace_value = "true";
87   std::map<std::string, std::string> model_sharing{{calc_workspace_key, calc_workspace_value}};
88   config_info_[sharing_workspace_section] = model_sharing;
89   session->SetConfigInfo(&config_info_);
90   session->SetPrepareSessionFlag(true);
91   auto ret = session->Init(ContextUtils::Convert(ms_context.get()));
92   if (ret != mindspore::lite::RET_OK) {
93     MS_LOG(ERROR) << "init session failed";
94     delete session;
95     return nullptr;
96   }
97   return session;
98 }
99 
CalMaxSizeOfWorkspace(ModelType model_type,const std::shared_ptr<Context> & ms_context)100 Status ModelGroupImpl::CalMaxSizeOfWorkspace(ModelType model_type, const std::shared_ptr<Context> &ms_context) {
101   if (flags_ != ModelGroupFlag::kShareWorkspace) {
102     MS_LOG(ERROR) << "Only support share workspace for ModelGroup::CalMaxSizeOfWorkspace";
103     return kLiteError;
104   }
105   for (auto &model_path : model_path_list_) {
106     auto *session = CreateLiteSession(ms_context);
107     if (session == nullptr) {
108       MS_LOG(ERROR) << "Calculate the maximum workspace size of the model " << model_path << " failed.";
109       ModelManager::GetInstance().ClearModel();
110       return kLiteError;
111     }
112     auto ret = session->LoadModelAndCompileByPath(model_path, model_type);
113     if (ret != mindspore::lite::RET_OK) {
114       MS_LOG(ERROR) << "Calculate the maximum workspace size of the model " << model_path << " failed.";
115       delete session;
116       session = nullptr;
117       ModelManager::GetInstance().ClearModel();
118       return kLiteError;
119     }
120     ModelManager::GetInstance().AddModel(model_path);
121     delete session;
122     session = nullptr;
123   }
124 
125   for (auto &model_buff : model_buff_list_) {
126     auto *session = CreateLiteSession(ms_context);
127     if (session == nullptr) {
128       MS_LOG(ERROR) << "Calculate the maximum workspace size of the model failed.";
129       ModelManager::GetInstance().ClearModel();
130       return kLiteError;
131     }
132     auto ret =
133       session->LoadModelAndCompileByBuf(static_cast<const char *>(model_buff.first), model_type, model_buff.second);
134     if (ret != mindspore::lite::RET_OK) {
135       MS_LOG(ERROR) << "Calculate the maximum workspace size of the model failed.";
136       delete session;
137       session = nullptr;
138       ModelManager::GetInstance().ClearModel();
139       return kLiteError;
140     }
141     ModelManager::GetInstance().AddModel(model_buff);
142     delete session;
143     session = nullptr;
144   }
145   return kSuccess;
146 }
147 }  // namespace mindspore
148