/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/cxx_api/model/model_impl.h"
#include <memory>
#include <algorithm>
#include <map>
#include <string>
#include <vector>
#include "include/api/types.h"
#include "include/api/context.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "src/runtime/inner_allocator.h"
#include "src/cxx_api/converters.h"
#include "src/cxx_api/graph/graph_data.h"
#include "src/cxx_api/tensor/tensor_impl.h"
#include "src/cxx_api/tensor_utils.h"
#include "src/common/log_adapter.h"
#include "src/lite_session.h"
#include "src/common/file_utils.h"
#include "src/common/config_file.h"

namespace mindspore {
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;

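// Process-wide holder for the train-session factory callback: a non-null
// argument installs a new callback, and the currently installed callback
// (possibly nullptr) is returned either way.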
CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto) {
  static CreateTrainSessionProto *proto_ = nullptr;
  if (proto != nullptr) {
    proto_ = proto;
  }
  return proto_;
}

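// Builds a model from an in-memory model buffer: validates the buffer,
// creates a LiteSession from the converted context, and compiles the graph.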
Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type,
                        const std::shared_ptr<Context> &ms_context) {
  if (model_data == nullptr) {
    MS_LOG(ERROR) << "The input model buffer is nullptr.";
    return kLiteNullptr;
  }
  if (data_size == 0) {
    MS_LOG(ERROR) << "The input model buffer size is 0.";
    return kLiteInputParamInvalid;
  }
  context_ = ms_context;
  auto session = std::shared_ptr<lite::LiteSession>(CreateLiteSession(ContextUtils::Convert(ms_context.get())));
  if (session == nullptr) {
    MS_LOG(ERROR) << "Allocate session failed.";
    return kLiteNullptr;
  }

  auto ret = session->LoadModelAndCompileByBuf(static_cast<const char *>(model_data), data_size);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init session failed.";
    return kLiteError;
  }

  session_.swap(session);
  MS_LOG(DEBUG) << "Build model success.";
  return kSuccess;
}

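// Builds a model directly from a file path; loading and compilation are
// delegated to the session.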
Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
                        const std::shared_ptr<Context> &ms_context) {
  auto session = std::shared_ptr<lite::LiteSession>(CreateLiteSession(ContextUtils::Convert(ms_context.get())));
  if (session == nullptr) {
    MS_LOG(ERROR) << "Allocate session failed.";
    return kLiteNullptr;
  }

  auto ret = session->LoadModelAndCompileByPath(model_path);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init session failed.";
    return kLiteError;
  }

  session_.swap(session);
  MS_LOG(DEBUG) << "Build model success.";
  return kSuccess;
}

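// Builds a model from a previously loaded Graph. A registered train-session
// callback, if any, is given the first chance to create the session;
// otherwise a plain inference session is created and the lite model is
// compiled into it, after which the model buffer is freed.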
Status ModelImpl::Build() {
  MS_LOG(DEBUG) << "Start build model.";
  if (graph_ == nullptr || graph_->graph_data_ == nullptr) {
    MS_LOG(ERROR) << "Invalid graph.";
    return kLiteNullptr;
  }

  if (context_ == nullptr) {
    MS_LOG(ERROR) << "Invalid context.";
    return kLiteNullptr;
  }

  auto *inner_context = ContextUtils::Convert(context_.get());
  if (inner_context == nullptr) {
    MS_LOG(ERROR) << "Failed to convert Context to Lite Context.";
    return kLiteNullptr;
  }

  auto create_callback = CreateTrainSessionCallbackHolder();
  if (create_callback != nullptr) {
    auto session = create_callback(graph_->graph_data_, cfg_, inner_context);
    if (session != nullptr) {
      session_ = session;
      MS_LOG(DEBUG) << "Build model success.";
      return kSuccess;
    }
  }

  auto model = graph_->graph_data_->lite_model();
  if (model == nullptr || model->buf == nullptr) {
    delete inner_context;
    MS_LOG(ERROR) << "Lite model has been freed.";
    return kLiteError;
  }

  auto session = std::shared_ptr<lite::LiteSession>(CreateLiteSession(inner_context));
  if (session == nullptr) {
    MS_LOG(ERROR) << "Allocate session failed.";
    return kLiteNullptr;
  }
  auto ret = session->CompileGraph(model.get());
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Build model failed.";
    return static_cast<StatusCode>(ret);
  }
  session_.swap(session);
  model->Free();
  MS_LOG(DEBUG) << "Build model success.";
  return kSuccess;
}

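// Restores the data pointers saved in old_data back into the given session
// tensors, so the session never ends up owning user-supplied input buffers.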
static void ResetTensorData(const std::vector<void *> &old_data, const std::vector<tensor::MSTensor *> &tensors) {
  for (size_t j = 0; j < old_data.size(); j++) {
    tensors.at(j)->set_data(old_data.at(j));
  }
}

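// Runs the compiled graph. When before/after callbacks are supplied, they are
// wrapped in lambdas that translate the session's lite tensors and callback
// parameters into the public MSTensor/MSCallBackParam types.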
Status ModelImpl::RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after) {
  if (before == nullptr || after == nullptr) {
    auto ret = session_->RunGraph();
    return static_cast<StatusCode>(ret);
  }
  auto before_call_back = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
                              const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
                              const CallBackParam &call_param) {
    std::vector<MSTensor> inputs = LiteTensorsToMSTensors(before_inputs);
    std::vector<MSTensor> outputs = LiteTensorsToMSTensors(before_outputs);
    MSCallBackParam mscall_param;
    mscall_param.node_name = call_param.node_name;
    mscall_param.node_type = call_param.node_type;
    return before(inputs, outputs, mscall_param);
  };

  auto after_call_back = [&](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
                             const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
                             const CallBackParam &call_param) {
    std::vector<MSTensor> inputs = LiteTensorsToMSTensors(after_inputs);
    std::vector<MSTensor> outputs = LiteTensorsToMSTensors(after_outputs);
    MSCallBackParam mscall_param;
    mscall_param.node_name = call_param.node_name;
    mscall_param.node_type = call_param.node_type;
    return after(inputs, outputs, mscall_param);
  };
  auto ret = session_->RunGraph(before_call_back, after_call_back);
  return static_cast<StatusCode>(ret);
}

bool ModelImpl::IsTrainModel() { return (graph_ && graph_->graph_data_ && graph_->graph_data_->IsTrainModel()); }

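// Reads the execution-plan section from the given config file and parses it
// into execution_plan_, which is later handed to newly created sessions.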
Status ModelImpl::LoadConfig(const std::string &config_path) {
  std::map<std::string, std::string> config_info;
  int ret = lite::GetSectionInfoFromConfigFile(config_path, CONFIG_FILE_EXECUTION_PLAN, &config_info);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GetSectionInfoFromConfigFile failed.";
    return kLiteFileError;
  }

  if (config_info.empty()) {
    MS_LOG(WARNING) << "No valid info in config file.";
    return kSuccess;
  }

  lite::ParserExecutionPlan(&config_info, &execution_plan_);
  return kSuccess;
}

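// Runs one inference pass: validates the user-supplied inputs against the
// session's input tensors, temporarily points the session inputs at the user
// buffers (saving the old pointers), runs the graph, then restores the
// original pointers and collects the outputs.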
Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                          const MSKernelCallBack &before, const MSKernelCallBack &after) {
  if (outputs == nullptr) {
    MS_LOG(ERROR) << "outputs is nullptr.";
    return kLiteError;
  }
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "Session is null.";
    return kLiteError;
  }
  auto input_tensors = session_->GetInputs();
  if (input_tensors.empty()) {
    MS_LOG(ERROR) << "Failed to get input tensor.";
    return kLiteError;
  }
  if (input_tensors.size() != inputs.size()) {
    MS_LOG(ERROR) << "Wrong input size.";
    return kLiteError;
  }
  std::vector<void *> old_data;
  for (size_t i = 0; i < inputs.size(); i++) {
    auto input = input_tensors.at(i);
    auto user_input = inputs.at(i);
    if (user_input.DataType() != static_cast<enum DataType>(input->data_type())) {
      ResetTensorData(old_data, input_tensors);
      MS_LOG(ERROR) << "Tensor " << user_input.Name() << " has a different data type from input "
                    << input->tensor_name() << ".";
      return kLiteInputTensorError;
    }
    if (user_input.Data() == nullptr) {
      ResetTensorData(old_data, input_tensors);
      MS_LOG(ERROR) << "Tensor " << user_input.Name() << " has no data.";
      return kLiteInputTensorError;
    }
    if (user_input.Name() != input->tensor_name()) {
      MS_LOG(WARNING) << "Tensor " << user_input.Name() << " has a different name from input "
                      << input->tensor_name() << ".";
    }
    old_data.push_back(input->data());
    if (input->data_type() == kObjectTypeString) {
#ifndef STRING_KERNEL_CLIP
      std::vector<int32_t> shape = TruncateShape(user_input.Shape(), input->data_type(), user_input.DataSize(), false);
      if (shape.empty() && !(user_input.Shape().empty())) {
        ResetTensorData(old_data, input_tensors);
        MS_LOG(ERROR) << "Input dims of tensor " << user_input.Name() << " are invalid.";
        return kLiteParamInvalid;
      }
      input->set_shape(shape);
      input->set_data(user_input.MutableData());
#else
      MS_LOG(ERROR) << unsupport_string_tensor_log;
      return kLiteError;
#endif
    } else {
      if (user_input.MutableData() != input->data()) {
        if (input->Size() != user_input.DataSize()) {
          ResetTensorData(old_data, input_tensors);
          MS_LOG(ERROR) << "Tensor " << user_input.Name() << " has wrong data size.";
          return kLiteInputTensorError;
        }
        input->set_data(user_input.MutableData());
      }
    }
  }
  auto ret = RunGraph(before, after);
  ResetTensorData(old_data, input_tensors);
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Run graph failed.";
    return ret;
  }
  MS_LOG(DEBUG) << "Run graph success.";
  auto res = GetOutputs();
  if (res.empty()) {
    MS_LOG(ERROR) << "Empty outputs.";
    return kLiteError;
  }
  outputs->clear();
  outputs->insert(outputs->end(), res.begin(), res.end());
  return kSuccess;
}

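// Wraps the session's input tensors in MSTensor handles; returns an empty
// vector on any failure.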
std::vector<MSTensor> ModelImpl::GetInputs() {
  std::vector<MSTensor> empty;
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "Session is null.";
    return empty;
  }
  std::vector<MSTensor> res;
  auto inputs = session_->GetInputs();
  if (inputs.empty()) {
    MS_LOG(ERROR) << "The inputs of the model are empty.";
    return empty;
  }
  res.resize(inputs.size());
  for (size_t i = 0; i < inputs.size(); i++) {
    auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(inputs[i]));
    if (impl == nullptr || impl->lite_tensor() == nullptr) {
      MS_LOG(ERROR) << "Create tensor failed.";
      return empty;
    }
    auto tensor = MSTensor(impl);
    if (tensor == nullptr) {
      MS_LOG(ERROR) << "Create tensor failed.";
      return empty;
    }
    res[i] = tensor;
  }
  return res;
}

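// Wraps the session's output tensors in MSTensor handles, ordered by the
// session's output tensor names; returns an empty vector on any failure.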
std::vector<MSTensor> ModelImpl::GetOutputs() {
  std::vector<MSTensor> empty;
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "Session is null.";
    return empty;
  }
  std::vector<MSTensor> res;
  auto names = session_->GetOutputTensorNames();
  if (names.empty()) {
    MS_LOG(ERROR) << "The output tensor names of this model are empty.";
    return empty;
  }
  auto outputs = session_->GetOutputs();
  if (outputs.empty()) {
    MS_LOG(ERROR) << "The outputs of the model are empty.";
    return empty;
  }
  if (names.size() != outputs.size()) {
    MS_LOG(ERROR) << "The size of outputs does not match the size of names.";
    return empty;
  }
  res.resize(names.size());
  for (size_t i = 0; i < names.size(); i++) {
    auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(outputs[names[i]]));
    if (impl == nullptr || impl->lite_tensor() == nullptr) {
      MS_LOG(ERROR) << "Create tensor failed.";
      return empty;
    }
    auto tensor = MSTensor(impl);
    if (tensor == nullptr) {
      MS_LOG(ERROR) << "Create tensor failed.";
      return empty;
    }
    res[i] = tensor;
  }
  return res;
}

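// Returns the gradient tensors held by the (training) session, wrapped as
// MSTensor handles; the `false` flag presumably leaves tensor ownership with
// the session.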
std::vector<MSTensor> ModelImpl::GetGradients() const {
  std::vector<MSTensor> empty;
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "Session is null.";
    return empty;
  }
  auto params = session_->GetGradients();
  if (params.empty()) {
    MS_LOG(ERROR) << "No gradients available.";
    return empty;
  }
  std::vector<MSTensor> res = LiteTensorsToMSTensors(params, false);
  return res;
}

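// Unwraps the public gradient tensors to their lite counterparts and applies
// them through the session.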
Status ModelImpl::ApplyGradients(const std::vector<MSTensor> &gradients) {
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "Session is null.";
    return kLiteNullptr;
  }
  if (gradients.empty()) {
    MS_LOG(ERROR) << "Gradients are empty.";
    return kLiteInputParamInvalid;
  }
  std::vector<tensor::MSTensor *> inner_gradients;
  inner_gradients.resize(gradients.size());
  for (size_t i = 0; i < gradients.size(); i++) {
    auto gradient = gradients[i];
    if (gradient.impl_ == nullptr || gradient.impl_->lite_tensor() == nullptr) {
      MS_LOG(ERROR) << "Gradient tensor " << gradient.Name() << " is null.";
      return kLiteInputTensorError;
    }
    inner_gradients[i] = gradient.impl_->lite_tensor();
  }
  auto ret = session_->ApplyGradients(inner_gradients);
  return static_cast<StatusCode>(ret);
}

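// Returns the optimizer parameter tensors of the (training) session wrapped
// as MSTensor handles.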
std::vector<MSTensor> ModelImpl::GetOptimizerParams() const {
  std::vector<MSTensor> empty;
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "Session is null.";
    return empty;
  }
  auto params = session_->GetOptimizerParams();
  if (params.empty()) {
    MS_LOG(ERROR) << "No optimizer parameters available.";
    return empty;
  }
  std::vector<MSTensor> res = LiteTensorsToMSTensors(params);
  return res;
}

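// Unwraps the public parameter tensors to their lite counterparts and hands
// them to the session's optimizer.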
Status ModelImpl::SetOptimizerParams(const std::vector<MSTensor> &params) {
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "Session is null.";
    return kLiteNullptr;
  }
  if (params.empty()) {
    MS_LOG(ERROR) << "Params are empty.";
    return kLiteInputParamInvalid;
  }
  std::vector<tensor::MSTensor *> inner_params;
  inner_params.resize(params.size());
  for (size_t i = 0; i < params.size(); i++) {
    auto param = params[i];
    if (param.impl_ == nullptr || param.impl_->lite_tensor() == nullptr) {
      MS_LOG(ERROR) << "Param tensor " << param.Name() << " is null.";
      return kLiteInputTensorError;
    }
    inner_params[i] = param.impl_->lite_tensor();
  }
  auto ret = session_->SetOptimizerParams(inner_params);
  return static_cast<StatusCode>(ret);
}

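// Looks up a single input tensor by name and wraps it in an MSTensor handle;
// a null handle is returned when the tensor is missing or wrapping fails.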
MSTensor ModelImpl::GetInputByTensorName(const std::string &name) {
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "Session is null.";
    return MSTensor(nullptr);
  }
  auto res = session_->GetInputsByTensorName(name);
  if (res == nullptr) {
    MS_LOG(ERROR) << "Model does not contain tensor " << name << ".";
    return MSTensor(nullptr);
  }
  auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(res));
  if (impl == nullptr || impl->lite_tensor() == nullptr) {
    MS_LOG(ERROR) << "Create tensor failed.";
    return MSTensor(nullptr);
  }

  return MSTensor(impl);
}

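// Forwards the session's output tensor names, or an empty vector when no
// session has been built.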
std::vector<std::string> ModelImpl::GetOutputTensorNames() {
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "Session is null.";
    std::vector<std::string> empty;
    return empty;
  }
  return session_->GetOutputTensorNames();
}

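// Looks up a single output tensor by name; mirrors GetInputByTensorName.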
MSTensor ModelImpl::GetOutputByTensorName(const std::string &name) {
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "Session is null.";
    return MSTensor(nullptr);
  }
  auto res = session_->GetOutputByTensorName(name);
  if (res == nullptr) {
    MS_LOG(ERROR) << "Model does not contain tensor " << name << ".";
    return MSTensor(nullptr);
  }
  auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(res));
  if (impl == nullptr || impl->lite_tensor() == nullptr) {
    MS_LOG(ERROR) << "Create tensor failed.";
    return MSTensor(nullptr);
  }

  return MSTensor(impl);
}

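// Wraps all output tensors produced by the named node in MSTensor handles;
// returns an empty vector on any failure.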
std::vector<MSTensor> ModelImpl::GetOutputsByNodeName(const std::string &name) {
  std::vector<MSTensor> empty;
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "Session is null.";
    return empty;
  }
  std::vector<MSTensor> res;
  auto outputs = session_->GetOutputsByNodeName(name);
  if (outputs.empty()) {
    MS_LOG(ERROR) << "The outputs of node " << name << " are empty.";
    return empty;
  }
  res.resize(outputs.size());
  for (size_t i = 0; i < outputs.size(); i++) {
    auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(outputs[i]));
    if (impl == nullptr || impl->lite_tensor() == nullptr) {
      MS_LOG(ERROR) << "Create tensor failed.";
      return empty;
    }
    auto tensor = MSTensor(impl);
    if (tensor == nullptr) {
      MS_LOG(ERROR) << "Create tensor failed.";
      return empty;
    }
    res[i] = tensor;
  }
  return res;
}

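// Resizes the model inputs to the given dims: validates the request against
// the session's inputs, truncates the 64-bit dims to the 32-bit shapes the
// session expects, and delegates the actual resize to the session.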
Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "Session is null.";
    return kLiteNullptr;
  }
  if (inputs.empty()) {
    MS_LOG(ERROR) << "Inputs are empty.";
    return kLiteInputParamInvalid;
  }
  if (dims.empty()) {
    MS_LOG(ERROR) << "Dims are empty.";
    return kLiteInputParamInvalid;
  }
  if (inputs.size() != dims.size()) {
    MS_LOG(ERROR) << "The size of inputs does not match the size of dims.";
    return kLiteInputParamInvalid;
  }
  auto model_inputs = session_->GetInputs();
  if (model_inputs.empty()) {
    MS_LOG(ERROR) << "The inputs of the model are empty.";
    return kLiteParamInvalid;
  }
  if (inputs.size() != model_inputs.size()) {
    MS_LOG(ERROR) << "The size of inputs is incorrect.";
    return kLiteInputParamInvalid;
  }
  std::vector<tensor::MSTensor *> inner_input;
  inner_input.resize(inputs.size());
  std::vector<std::vector<int32_t>> truncated_shape;
  truncated_shape.resize(inputs.size());
  for (size_t i = 0; i < inputs.size(); i++) {
    auto input = inputs[i];
    if (input.impl_ == nullptr || input.impl_->lite_tensor() == nullptr) {
      MS_LOG(ERROR) << "Input tensor " << input.Name() << " is null.";
      return kLiteInputTensorError;
    }
    inner_input[i] = input.impl_->lite_tensor();
    std::vector<int32_t> shape = TruncateShape(dims[i], inner_input[i]->data_type(), inner_input[i]->Size(), false);
    if (shape.empty() && !(dims[i].empty())) {
      MS_LOG(ERROR) << "Input dims[" << i << "] is invalid.";
      return kLiteParamInvalid;
    }
    truncated_shape[i] = shape;
  }
  auto ret = session_->Resize(inner_input, truncated_shape);
  return static_cast<StatusCode>(ret);
}

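// Creates and initializes a LiteSession over the given inner context. The
// context appears to be owned by the callee: it is deleted here when session
// allocation fails, and otherwise passed to session->Init().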
lite::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context) {
  auto session = new (std::nothrow) lite::LiteSession();
  if (session == nullptr) {
    MS_LOG(ERROR) << "Create session failed.";
    delete context;
    return nullptr;
  }

  session->InitExecutionConfig(&execution_plan_);

  auto ret = session->Init(context);
  if (ret != mindspore::lite::RET_OK) {
    MS_LOG(ERROR) << "Init session failed.";
    delete session;
    return nullptr;
  }
  return session;
}
}  // namespace mindspore