1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/litert/cxx_api/model/model_group_impl.h"
18 #include <memory>
19 #include <algorithm>
20 #include <map>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 #include "include/api/types.h"
25 #include "include/api/context.h"
26 #include "src/litert/cxx_api/converters.h"
27 #include "src/common/log_adapter.h"
28 #include "src/litert/lite_session.h"
29 #include "src/litert/model_manager.h"
30 #include "src/common/config_file.h"
31
32 namespace mindspore {
33 using mindspore::lite::RET_OK;
ModelGroupImpl(ModelGroupFlag flags)34 ModelGroupImpl::ModelGroupImpl(ModelGroupFlag flags) : flags_(flags) {
35 static uint32_t g_model_group_id = 0;
36 model_group_id_ = ++g_model_group_id;
37 }
38
AddModel(const std::vector<std::string> & model_path_list)39 Status ModelGroupImpl::AddModel(const std::vector<std::string> &model_path_list) {
40 if (flags_ != ModelGroupFlag::kShareWorkspace) {
41 MS_LOG(ERROR) << "Only support share workspace for ModelGroup::AddModel(const std::vector<std::string> &)";
42 return kLiteError;
43 }
44 if (model_path_list.empty()) {
45 MS_LOG(ERROR) << "Param model_path_list is empty.";
46 return kLiteParamInvalid;
47 }
48 for (auto &model_path : model_path_list) {
49 if (model_path.empty()) {
50 continue;
51 }
52 (void)model_path_list_.emplace_back(model_path);
53 }
54
55 return kSuccess;
56 }
57
AddModel(const std::vector<std::pair<const void *,size_t>> & model_buff_list)58 Status ModelGroupImpl::AddModel(const std::vector<std::pair<const void *, size_t>> &model_buff_list) {
59 if (flags_ != ModelGroupFlag::kShareWorkspace) {
60 MS_LOG(ERROR)
61 << "Only support share workspace for ModelGroup::AddModel(const std::vector<std::pair<const void *, size_t>> &)";
62 return kLiteError;
63 }
64 if (model_buff_list.empty()) {
65 MS_LOG(ERROR) << "Param model_buff_list is empty.";
66 return kLiteParamInvalid;
67 }
68 for (auto &model_buff : model_buff_list) {
69 if (model_buff.first == nullptr || model_buff.second == 0) {
70 continue;
71 }
72 (void)model_buff_list_.emplace_back(model_buff);
73 }
74
75 return kSuccess;
76 }
77
CreateLiteSession(const std::shared_ptr<Context> & ms_context)78 lite::LiteSession *ModelGroupImpl::CreateLiteSession(const std::shared_ptr<Context> &ms_context) {
79 auto session = new (std::nothrow) lite::LiteSession();
80 if (session == nullptr) {
81 return nullptr;
82 }
83
84 std::string sharing_workspace_section = "inner_common";
85 std::string calc_workspace_key = "inner_calc_workspace_size";
86 std::string calc_workspace_value = "true";
87 std::map<std::string, std::string> model_sharing{{calc_workspace_key, calc_workspace_value}};
88 config_info_[sharing_workspace_section] = model_sharing;
89 session->SetConfigInfo(&config_info_);
90 session->SetPrepareSessionFlag(true);
91 auto ret = session->Init(ContextUtils::Convert(ms_context.get()));
92 if (ret != mindspore::lite::RET_OK) {
93 MS_LOG(ERROR) << "init session failed";
94 delete session;
95 return nullptr;
96 }
97 return session;
98 }
99
CalMaxSizeOfWorkspace(ModelType model_type,const std::shared_ptr<Context> & ms_context)100 Status ModelGroupImpl::CalMaxSizeOfWorkspace(ModelType model_type, const std::shared_ptr<Context> &ms_context) {
101 if (flags_ != ModelGroupFlag::kShareWorkspace) {
102 MS_LOG(ERROR) << "Only support share workspace for ModelGroup::CalMaxSizeOfWorkspace";
103 return kLiteError;
104 }
105 for (auto &model_path : model_path_list_) {
106 auto *session = CreateLiteSession(ms_context);
107 if (session == nullptr) {
108 MS_LOG(ERROR) << "Calculate the maximum workspace size of the model " << model_path << " failed.";
109 ModelManager::GetInstance().ClearModel();
110 return kLiteError;
111 }
112 auto ret = session->LoadModelAndCompileByPath(model_path, model_type);
113 if (ret != mindspore::lite::RET_OK) {
114 MS_LOG(ERROR) << "Calculate the maximum workspace size of the model " << model_path << " failed.";
115 delete session;
116 session = nullptr;
117 ModelManager::GetInstance().ClearModel();
118 return kLiteError;
119 }
120 ModelManager::GetInstance().AddModel(model_path);
121 delete session;
122 session = nullptr;
123 }
124
125 for (auto &model_buff : model_buff_list_) {
126 auto *session = CreateLiteSession(ms_context);
127 if (session == nullptr) {
128 MS_LOG(ERROR) << "Calculate the maximum workspace size of the model failed.";
129 ModelManager::GetInstance().ClearModel();
130 return kLiteError;
131 }
132 auto ret =
133 session->LoadModelAndCompileByBuf(static_cast<const char *>(model_buff.first), model_type, model_buff.second);
134 if (ret != mindspore::lite::RET_OK) {
135 MS_LOG(ERROR) << "Calculate the maximum workspace size of the model failed.";
136 delete session;
137 session = nullptr;
138 ModelManager::GetInstance().ClearModel();
139 return kLiteError;
140 }
141 ModelManager::GetInstance().AddModel(model_buff);
142 delete session;
143 session = nullptr;
144 }
145 return kSuccess;
146 }
147 } // namespace mindspore
148