1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/api/model.h"
18 #include <mutex>
19 #include "include/api/types.h"
20 #include "include/api/context.h"
21 #include "include/api/callback/callback.h"
22 #include "include/api/dual_abi_helper.h"
23 #include "src/cxx_api/model/model_impl.h"
24 #include "src/cxx_api/callback/callback_impl.h"
25 #include "src/cxx_api/callback/callback_adapter.h"
26 #include "src/common/log_adapter.h"
27
28 namespace mindspore {
// Global mutex guarding the lazy creation of Model::impl_ (used by the Build overloads and LoadConfig).
std::mutex g_impl_init_lock{};
30
Build(const void * model_data,size_t data_size,ModelType model_type,const std::shared_ptr<Context> & model_context,const Key & dec_key,const std::string & dec_mode)31 Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
32 const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
33 if (impl_ == nullptr) {
34 std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
35 impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
36 if (impl_ == nullptr) {
37 MS_LOG(ERROR) << "Model implement is null.";
38 return kLiteFileError;
39 }
40 }
41
42 Status ret = impl_->Build(model_data, data_size, model_type, model_context);
43 if (ret != kSuccess) {
44 return ret;
45 }
46 return kSuccess;
47 }
48
Build(const std::string & model_path,ModelType model_type,const std::shared_ptr<Context> & model_context,const Key & dec_key,const std::string & dec_mode)49 Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
50 const Key &dec_key, const std::string &dec_mode) {
51 if (impl_ == nullptr) {
52 std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
53 impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
54 if (impl_ == nullptr) {
55 MS_LOG(ERROR) << "Model implement is null.";
56 return kLiteFileError;
57 }
58 }
59
60 Status ret = impl_->Build(model_path, model_type, model_context);
61 if (ret != kSuccess) {
62 return ret;
63 }
64 return kSuccess;
65 }
66
Build(GraphCell graph,const std::shared_ptr<Context> & model_context,const std::shared_ptr<TrainCfg> & train_cfg)67 Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_context,
68 const std::shared_ptr<TrainCfg> &train_cfg) {
69 std::stringstream err_msg;
70 if (impl_ == nullptr) {
71 std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
72 impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
73 if (impl_ == nullptr) {
74 MS_LOG(ERROR) << "Model implement is null.";
75 return kLiteFileError;
76 }
77 }
78
79 if (graph.GetGraph() == nullptr) {
80 err_msg << "Invalid null graph.";
81 MS_LOG(ERROR) << err_msg.str();
82 return Status(kLiteNullptr, err_msg.str());
83 }
84 if (model_context == nullptr) {
85 err_msg << "Invalid null context.";
86 MS_LOG(ERROR) << err_msg.str();
87 return Status(kLiteNullptr, err_msg.str());
88 }
89 impl_->SetContext(model_context);
90 impl_->SetGraph(graph.GetGraph());
91 impl_->SetConfig(train_cfg);
92 return impl_->Build();
93 }
94
Resize(const std::vector<MSTensor> & inputs,const std::vector<std::vector<int64_t>> & dims)95 Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
96 if (impl_ == nullptr) {
97 MS_LOG(ERROR) << "Model implement is null.";
98 return kLiteNullptr;
99 }
100 return impl_->Resize(inputs, dims);
101 }
102
Predict(const std::vector<MSTensor> & inputs,std::vector<MSTensor> * outputs,const MSKernelCallBack & before,const MSKernelCallBack & after)103 Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
104 const MSKernelCallBack &before, const MSKernelCallBack &after) {
105 if (impl_ == nullptr) {
106 MS_LOG(ERROR) << "Model implement is null.";
107 return kLiteNullptr;
108 }
109 return impl_->Predict(inputs, outputs, before, after);
110 }
111
PredictWithPreprocess(const std::vector<MSTensor> & inputs,std::vector<MSTensor> * outputs,const MSKernelCallBack & before,const MSKernelCallBack & after)112 Status Model::PredictWithPreprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
113 const MSKernelCallBack &before, const MSKernelCallBack &after) {
114 MS_LOG(ERROR) << "Unsupported Feature.";
115 return kLiteNotSupport;
116 }
117
Preprocess(const std::vector<MSTensor> & inputs,std::vector<MSTensor> * outputs)118 Status Model::Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
119 MS_LOG(ERROR) << "Unsupported Feature.";
120 return kLiteNotSupport;
121 }
122
HasPreprocess()123 bool Model::HasPreprocess() {
124 MS_LOG(ERROR) << "Unsupported Feature.";
125 return false;
126 }
127
Model()128 Model::Model() : impl_(nullptr) {}
129
~Model()130 Model::~Model() {}
131
CheckModelSupport(enum DeviceType device_type,ModelType model_type)132 bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) {
133 MS_LOG(ERROR) << "Unsupported feature.";
134 return false;
135 }
136
GetInputs()137 std::vector<MSTensor> Model::GetInputs() {
138 std::vector<MSTensor> empty;
139 if (impl_ == nullptr) {
140 MS_LOG(ERROR) << "Model implement is null.";
141 return empty;
142 }
143 return impl_->GetInputs();
144 }
145
GetOutputs()146 std::vector<MSTensor> Model::GetOutputs() {
147 std::vector<MSTensor> empty;
148 if (impl_ == nullptr) {
149 MS_LOG(ERROR) << "Model implement is null.";
150 return empty;
151 }
152 return impl_->GetOutputs();
153 }
154
GetInputByTensorName(const std::vector<char> & name)155 MSTensor Model::GetInputByTensorName(const std::vector<char> &name) {
156 if (impl_ == nullptr) {
157 MS_LOG(ERROR) << "Model implement is null.";
158 return MSTensor(nullptr);
159 }
160 return impl_->GetInputByTensorName(CharToString(name));
161 }
162
GetOutputTensorNamesChar()163 std::vector<std::vector<char>> Model::GetOutputTensorNamesChar() {
164 if (impl_ == nullptr) {
165 MS_LOG(ERROR) << "Model implement is null.";
166 std::vector<std::vector<char>> empty;
167 return empty;
168 }
169 return VectorStringToChar(impl_->GetOutputTensorNames());
170 }
171
GetOutputByTensorName(const std::vector<char> & name)172 MSTensor Model::GetOutputByTensorName(const std::vector<char> &name) {
173 if (impl_ == nullptr) {
174 MS_LOG(ERROR) << "Model implement is null.";
175 return MSTensor(nullptr);
176 }
177 return impl_->GetOutputByTensorName(CharToString(name));
178 }
179
GetOutputsByNodeName(const std::vector<char> & node_name)180 std::vector<MSTensor> Model::GetOutputsByNodeName(const std::vector<char> &node_name) {
181 if (impl_ == nullptr) {
182 MS_LOG(ERROR) << "Model implement is null.";
183 std::vector<MSTensor> empty;
184 return empty;
185 }
186 return impl_->GetOutputsByNodeName(CharToString(node_name));
187 }
188
LoadConfig(const std::string & config_path)189 Status Model::LoadConfig(const std::string &config_path) {
190 std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
191 if (impl_ != nullptr) {
192 MS_LOG(ERROR) << "impl_ illegal in LoadConfig.";
193 return Status(kLiteFileError, "Illegal operation.");
194 }
195
196 impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
197 if (impl_ == nullptr) {
198 MS_LOG(ERROR) << "Model implement is null.";
199 return Status(kLiteFileError, "Fail to load config file.");
200 }
201
202 auto ret = impl_->LoadConfig(config_path);
203 if (ret != kSuccess) {
204 MS_LOG(ERROR) << "impl_ LoadConfig failed,";
205 return Status(kLiteFileError, "Invalid config file.");
206 }
207 return kSuccess;
208 }
209
SetTrainMode(bool train)210 Status Model::SetTrainMode(bool train) {
211 if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
212 MS_LOG(ERROR) << "Model is null.";
213 return kLiteUninitializedObj;
214 }
215 auto ret = (train) ? impl_->session_->Train() : impl_->session_->Eval();
216 return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
217 }
218
GetTrainMode() const219 bool Model::GetTrainMode() const { return ((impl_ != nullptr) && (impl_->session_) && (impl_->session_->IsTrain())); }
220
GetGradients() const221 std::vector<MSTensor> Model::GetGradients() const {
222 std::vector<MSTensor> empty;
223 if (impl_ == nullptr) {
224 MS_LOG(ERROR) << "Model implement is null.";
225 return empty;
226 }
227 return impl_->GetGradients();
228 }
229
ApplyGradients(const std::vector<MSTensor> & gradients)230 Status Model::ApplyGradients(const std::vector<MSTensor> &gradients) {
231 if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
232 MS_LOG(ERROR) << "Model is null.";
233 return kLiteUninitializedObj;
234 }
235 return impl_->ApplyGradients(gradients);
236 }
237
GetOptimizerParams() const238 std::vector<MSTensor> Model::GetOptimizerParams() const {
239 std::vector<MSTensor> empty;
240 if (impl_ == nullptr) {
241 MS_LOG(ERROR) << "Model implement is null.";
242 return empty;
243 }
244 auto res = impl_->GetOptimizerParams();
245 return res;
246 }
247
SetOptimizerParams(const std::vector<MSTensor> & params)248 Status Model::SetOptimizerParams(const std::vector<MSTensor> ¶ms) {
249 if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
250 MS_LOG(ERROR) << "Model is null.";
251 return kLiteUninitializedObj;
252 }
253 return impl_->SetOptimizerParams(params);
254 }
255
InitMetrics(std::vector<Metrics * > metrics)256 Status Model::InitMetrics(std::vector<Metrics *> metrics) {
257 if (impl_ == nullptr) {
258 MS_LOG(ERROR) << "Model implement is null.";
259 return kLiteUninitializedObj;
260 }
261 return impl_->InitMetrics(metrics);
262 }
263
GetMetrics()264 std::vector<Metrics *> Model::GetMetrics() {
265 std::vector<Metrics *> empty;
266 if (impl_ == nullptr) {
267 MS_LOG(ERROR) << "Model implement is null.";
268 return empty;
269 }
270 return impl_->GetMetrics();
271 }
272
273 } // namespace mindspore
274