• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/cxx_api/model/model_impl.h"
18 #include <memory>
19 #include <algorithm>
20 #include "include/api/types.h"
21 #include "include/api/context.h"
22 #include "include/lite_session.h"
23 #include "include/context.h"
24 #include "src/runtime/inner_allocator.h"
25 #include "src/cxx_api/converters.h"
26 #include "src/cxx_api/graph/graph_data.h"
27 #include "src/cxx_api/tensor/tensor_impl.h"
28 #include "src/cxx_api/tensor_utils.h"
29 #include "src/common/log_adapter.h"
30 #include "src/lite_session.h"
31 #include "src/common/file_utils.h"
32 #include "src/common/config_file.h"
33 
34 namespace mindspore {
35 using mindspore::lite::RET_ERROR;
36 using mindspore::lite::RET_OK;
37 
CreateTrainSessionCallbackHolder(CreateTrainSessionProto * proto)38 CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto) {
39   static CreateTrainSessionProto *proto_ = nullptr;
40   if (proto != nullptr) {
41     proto_ = proto;
42   }
43   return proto_;
44 }
45 
Build(const void * model_data,size_t data_size,ModelType model_type,const std::shared_ptr<Context> & ms_context)46 Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type,
47                         const std::shared_ptr<Context> &ms_context) {
48   if (model_data == nullptr) {
49     MS_LOG(ERROR) << "The input model buffer is nullptr.";
50     return kLiteNullptr;
51   }
52   if (data_size == 0) {
53     MS_LOG(ERROR) << "The input model buffer size is 0.";
54     return kLiteInputParamInvalid;
55   }
56   context_ = ms_context;
57   auto session = std::shared_ptr<lite::LiteSession>(CreateLiteSession(ContextUtils::Convert(ms_context.get())));
58   if (session == nullptr) {
59     MS_LOG(ERROR) << "Allocate session failed.";
60     return kLiteNullptr;
61   }
62 
63   auto ret = session->LoadModelAndCompileByBuf(static_cast<const char *>(model_data), data_size);
64   if (ret != RET_OK) {
65     MS_LOG(ERROR) << "Init session failed";
66     return kLiteError;
67   }
68 
69   session_.swap(session);
70   MS_LOG(DEBUG) << "Build model success.";
71   return kSuccess;
72 }
73 
Build(const std::string & model_path,ModelType model_type,const std::shared_ptr<Context> & ms_context)74 Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
75                         const std::shared_ptr<Context> &ms_context) {
76   auto session = std::shared_ptr<lite::LiteSession>(CreateLiteSession(ContextUtils::Convert(ms_context.get())));
77   if (session == nullptr) {
78     MS_LOG(ERROR) << "Allocate session failed.";
79     return kLiteNullptr;
80   }
81 
82   auto ret = session->LoadModelAndCompileByPath(model_path);
83   if (ret != RET_OK) {
84     MS_LOG(ERROR) << "Init session failed";
85     return kLiteError;
86   }
87 
88   session_.swap(session);
89   MS_LOG(DEBUG) << "Build model success.";
90   return kSuccess;
91 }
92 
Build()93 Status ModelImpl::Build() {
94   MS_LOG(DEBUG) << "Start build model.";
95   if (graph_ == nullptr || graph_->graph_data_ == nullptr) {
96     MS_LOG(ERROR) << "Invalid graph.";
97     return kLiteNullptr;
98   }
99 
100   if (context_ == nullptr) {
101     MS_LOG(ERROR) << "Invalid context.";
102     return kLiteNullptr;
103   }
104 
105   auto *inner_context = ContextUtils::Convert(context_.get());
106   if (inner_context == nullptr) {
107     MS_LOG(ERROR) << "Failed to convert Context to Lite Context";
108     return kLiteNullptr;
109   }
110 
111   auto create_callback = CreateTrainSessionCallbackHolder();
112   if (create_callback != nullptr) {
113     auto session = create_callback(graph_->graph_data_, cfg_, inner_context);
114     if (session != nullptr) {
115       session_ = session;
116       MS_LOG(DEBUG) << "Build model success.";
117       return kSuccess;
118     }
119   }
120 
121   auto model = graph_->graph_data_->lite_model();
122   if (model == nullptr || model->buf == nullptr) {
123     delete inner_context;
124     MS_LOG(ERROR) << "Lite model has been freed.";
125     return kLiteError;
126   }
127 
128   auto session = std::shared_ptr<lite::LiteSession>(CreateLiteSession(inner_context));
129   if (session == nullptr) {
130     MS_LOG(ERROR) << "Allocate session failed.";
131     return kLiteNullptr;
132   }
133   auto ret = session->CompileGraph(model.get());
134   if (ret != RET_OK) {
135     MS_LOG(ERROR) << "Build model failed.";
136     return static_cast<StatusCode>(ret);
137   }
138   session_.swap(session);
139   model->Free();
140   MS_LOG(DEBUG) << "Build model success.";
141   return kSuccess;
142 }
143 
ResetTensorData(std::vector<void * > old_data,std::vector<tensor::MSTensor * > tensors)144 static void ResetTensorData(std::vector<void *> old_data, std::vector<tensor::MSTensor *> tensors) {
145   for (size_t j = 0; j < old_data.size(); j++) {
146     tensors.at(j)->set_data(old_data.at(j));
147   }
148 }
149 
RunGraph(const MSKernelCallBack & before,const MSKernelCallBack & after)150 Status ModelImpl::RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after) {
151   if (before == nullptr || after == nullptr) {
152     auto ret = session_->RunGraph();
153     return static_cast<StatusCode>(ret);
154   }
155   auto before_call_back = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
156                               const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
157                               const CallBackParam &call_param) {
158     std::vector<MSTensor> inputs = LiteTensorsToMSTensors(before_inputs);
159     std::vector<MSTensor> outputs = LiteTensorsToMSTensors(before_outputs);
160     MSCallBackParam mscall_param;
161     mscall_param.node_name = call_param.node_name;
162     mscall_param.node_type = call_param.node_type;
163     return before(inputs, outputs, mscall_param);
164   };
165 
166   auto after_call_back = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
167                              const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
168                              const CallBackParam &call_param) {
169     std::vector<MSTensor> inputs = LiteTensorsToMSTensors(before_inputs);
170     std::vector<MSTensor> outputs = LiteTensorsToMSTensors(before_outputs);
171     MSCallBackParam mscall_param;
172     mscall_param.node_name = call_param.node_name;
173     mscall_param.node_type = call_param.node_type;
174     return after(inputs, outputs, mscall_param);
175   };
176   auto ret = session_->RunGraph(before_call_back, after_call_back);
177   return static_cast<StatusCode>(ret);
178 }
179 
IsTrainModel()180 bool ModelImpl::IsTrainModel() { return (graph_ && graph_->graph_data_ && graph_->graph_data_->IsTrainModel()); }
181 
LoadConfig(const std::string & config_path)182 Status ModelImpl::LoadConfig(const std::string &config_path) {
183   std::map<std::string, std::string> config_info;
184   int ret = lite::GetSectionInfoFromConfigFile(config_path, CONFIG_FILE_EXECUTION_PLAN, &config_info);
185   if (ret != RET_OK) {
186     MS_LOG(ERROR) << "GetSectionInfoFromConfigFile failed.";
187     return kLiteFileError;
188   }
189 
190   if (config_info.empty()) {
191     MS_LOG(WARNING) << "No valid info in config file.";
192     return kSuccess;
193   }
194 
195   lite::ParserExecutionPlan(&config_info, &execution_plan_);
196   return kSuccess;
197 }
198 
Predict(const std::vector<MSTensor> & inputs,std::vector<MSTensor> * outputs,const MSKernelCallBack & before,const MSKernelCallBack & after)199 Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
200                           const MSKernelCallBack &before, const MSKernelCallBack &after) {
201   if (outputs == nullptr) {
202     MS_LOG(ERROR) << "outputs is nullptr.";
203     return kLiteError;
204   }
205   if (session_ == nullptr) {
206     MS_LOG(ERROR) << "Run graph failed.";
207     return kLiteError;
208   }
209   auto input_tensors = session_->GetInputs();
210   if (input_tensors.empty()) {
211     MS_LOG(ERROR) << "Failed to get input tensor.";
212     return kLiteError;
213   }
214   if (input_tensors.size() != inputs.size()) {
215     MS_LOG(ERROR) << "Wrong input size.";
216     return kLiteError;
217   }
218   std::vector<void *> old_data;
219   for (size_t i = 0; i < inputs.size(); i++) {
220     auto input = input_tensors.at(i);
221     auto user_input = inputs.at(i);
222     if (user_input.DataType() != static_cast<enum DataType>(input->data_type())) {
223       ResetTensorData(old_data, input_tensors);
224       MS_LOG(ERROR) << "Tensor " << user_input.Name() << " has a different data type from input" << input->tensor_name()
225                     << ".";
226       return kLiteInputTensorError;
227     }
228     if (user_input.Data() == nullptr) {
229       ResetTensorData(old_data, input_tensors);
230       MS_LOG(ERROR) << "Tensor " << user_input.Name() << " has no data.";
231       return kLiteInputTensorError;
232     }
233     if (user_input.Name() != input->tensor_name()) {
234       MS_LOG(WARNING) << "Tensor " << user_input.Name() << " has a different name from input" << input->tensor_name()
235                       << ".";
236     }
237     old_data.push_back(input->data());
238     if (input->data_type() == kObjectTypeString) {
239 #ifndef STRING_KERNEL_CLIP
240       std::vector<int32_t> shape = TruncateShape(user_input.Shape(), input->data_type(), user_input.DataSize(), false);
241       if (shape.empty() && !(user_input.Shape().empty())) {
242         ResetTensorData(old_data, input_tensors);
243         MS_LOG(ERROR) << "Input dims of tensor " << user_input.Name() << " is invalid.";
244         return kLiteParamInvalid;
245       }
246       input->set_shape(shape);
247       input->set_data(user_input.MutableData());
248 #else
249       MS_LOG(ERROR) << unsupport_string_tensor_log;
250       return kLiteError;
251 #endif
252     } else {
253       if (user_input.MutableData() != input->data()) {
254         if (input->Size() != user_input.DataSize()) {
255           ResetTensorData(old_data, input_tensors);
256           MS_LOG(ERROR) << "Tensor " << user_input.Name() << " has wrong data size.";
257           return kLiteInputTensorError;
258         }
259         input->set_data(user_input.MutableData());
260       }
261     }
262   }
263   auto ret = RunGraph(before, after);
264   ResetTensorData(old_data, input_tensors);
265   if (ret != kSuccess) {
266     MS_LOG(ERROR) << "Run graph failed.";
267     return ret;
268   }
269   MS_LOG(DEBUG) << "Run graph success.";
270   auto res = GetOutputs();
271   if (res.empty()) {
272     MS_LOG(DEBUG) << "Empty outputs.";
273     return kLiteError;
274   }
275   outputs->clear();
276   outputs->insert(outputs->end(), res.begin(), res.end());
277   return kSuccess;
278 }
279 
GetInputs()280 std::vector<MSTensor> ModelImpl::GetInputs() {
281   std::vector<MSTensor> empty;
282   if (session_ == nullptr) {
283     MS_LOG(ERROR) << "Session is null.";
284     return empty;
285   }
286   std::vector<MSTensor> res;
287   auto inputs = session_->GetInputs();
288   if (inputs.empty()) {
289     MS_LOG(ERROR) << "The inputs of model is null.";
290     return empty;
291   }
292   res.resize(inputs.size());
293   for (size_t i = 0; i < inputs.size(); i++) {
294     auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(inputs[i]));
295     if (impl == nullptr || impl->lite_tensor() == nullptr) {
296       MS_LOG(ERROR) << "Create tensor failed.";
297       return empty;
298     }
299     auto tensor = MSTensor(impl);
300     if (tensor == nullptr) {
301       MS_LOG(ERROR) << "Create tensor failed.";
302       return empty;
303     }
304     res[i] = tensor;
305   }
306   return res;
307 }
308 
GetOutputs()309 std::vector<MSTensor> ModelImpl::GetOutputs() {
310   std::vector<MSTensor> empty;
311   if (session_ == nullptr) {
312     MS_LOG(ERROR) << "Session is null.";
313     return empty;
314   }
315   std::vector<MSTensor> res;
316   auto names = session_->GetOutputTensorNames();
317   if (names.empty()) {
318     MS_LOG(ERROR) << "The output tensor name of this model is null.";
319     return empty;
320   }
321   auto outputs = session_->GetOutputs();
322   if (outputs.empty()) {
323     MS_LOG(ERROR) << "The outputs of model is null.";
324     return empty;
325   }
326   if (names.size() != outputs.size()) {
327     MS_LOG(ERROR) << "The size of outputs dose not match the size of names.";
328     return empty;
329   }
330   res.resize(names.size());
331   for (size_t i = 0; i < names.size(); i++) {
332     auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(outputs[names[i]]));
333     if (impl == nullptr || impl->lite_tensor() == nullptr) {
334       MS_LOG(ERROR) << "Create tensor failed.";
335       return empty;
336     }
337     auto tensor = MSTensor(impl);
338     if (tensor == nullptr) {
339       MS_LOG(ERROR) << "Create tensor failed.";
340       return empty;
341     }
342     res[i] = tensor;
343   }
344   return res;
345 }
346 
GetGradients() const347 std::vector<MSTensor> ModelImpl::GetGradients() const {
348   std::vector<MSTensor> empty;
349   if (session_ == nullptr) {
350     MS_LOG(ERROR) << "Session is null.";
351     return empty;
352   }
353   auto params = session_->GetGradients();
354   if (params.empty()) {
355     MS_LOG(ERROR) << "No optimizer parameters avelibale.";
356     return empty;
357   }
358   std::vector<MSTensor> res = LiteTensorsToMSTensors(params, false);
359   return res;
360 }
361 
ApplyGradients(const std::vector<MSTensor> & gradients)362 Status ModelImpl::ApplyGradients(const std::vector<MSTensor> &gradients) {
363   if (session_ == nullptr) {
364     MS_LOG(ERROR) << "Session is null.";
365     return kLiteNullptr;
366   }
367   if (gradients.empty()) {
368     MS_LOG(ERROR) << "gradients is null.";
369     return kLiteInputParamInvalid;
370   }
371   std::vector<tensor::MSTensor *> inner_gradients;
372   inner_gradients.resize(gradients.size());
373   for (size_t i = 0; i < gradients.size(); i++) {
374     auto gradient = gradients[i];
375     if (gradient.impl_ == nullptr || gradient.impl_->lite_tensor() == nullptr) {
376       MS_LOG(ERROR) << "gradient tensor " << gradient.Name() << " is null.";
377       return kLiteInputTensorError;
378     }
379     inner_gradients[i] = gradient.impl_->lite_tensor();
380   }
381   auto ret = session_->ApplyGradients(inner_gradients);
382   return static_cast<StatusCode>(ret);
383 }
384 
GetOptimizerParams() const385 std::vector<MSTensor> ModelImpl::GetOptimizerParams() const {
386   std::vector<MSTensor> empty;
387   if (session_ == nullptr) {
388     MS_LOG(ERROR) << "Session is null.";
389     return empty;
390   }
391   auto params = session_->GetOptimizerParams();
392   if (params.empty()) {
393     MS_LOG(ERROR) << "No optimizer parameters avelibale.";
394     return empty;
395   }
396   std::vector<MSTensor> res = LiteTensorsToMSTensors(params);
397   return res;
398 }
399 
SetOptimizerParams(const std::vector<MSTensor> & params)400 Status ModelImpl::SetOptimizerParams(const std::vector<MSTensor> &params) {
401   if (session_ == nullptr) {
402     MS_LOG(ERROR) << "Session is null.";
403     return kLiteNullptr;
404   }
405   if (params.empty()) {
406     MS_LOG(ERROR) << "params is null.";
407     return kLiteInputParamInvalid;
408   }
409   std::vector<tensor::MSTensor *> inner_params;
410   inner_params.resize(params.size());
411   for (size_t i = 0; i < params.size(); i++) {
412     auto param = params[i];
413     if (param.impl_ == nullptr || param.impl_->lite_tensor() == nullptr) {
414       MS_LOG(ERROR) << "Param tensor " << param.Name() << " is null.";
415       return kLiteInputTensorError;
416     }
417     inner_params[i] = param.impl_->lite_tensor();
418   }
419   auto ret = session_->SetOptimizerParams(inner_params);
420   return static_cast<StatusCode>(ret);
421 }
422 
GetInputByTensorName(const std::string & name)423 MSTensor ModelImpl::GetInputByTensorName(const std::string &name) {
424   if (session_ == nullptr) {
425     MS_LOG(ERROR) << "Session is null.";
426     return MSTensor(nullptr);
427   }
428   auto res = session_->GetInputsByTensorName(name);
429   if (res == nullptr) {
430     MS_LOG(ERROR) << "Model does not contains tensor " << name << " .";
431     return MSTensor(nullptr);
432   }
433   auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(res));
434   if (impl == nullptr || impl->lite_tensor() == nullptr) {
435     MS_LOG(ERROR) << "Create tensor failed.";
436     return MSTensor(nullptr);
437   }
438 
439   return MSTensor(impl);
440 }
441 
GetOutputTensorNames()442 std::vector<std::string> ModelImpl::GetOutputTensorNames() {
443   if (session_ == nullptr) {
444     MS_LOG(ERROR) << "Session is null.";
445     std::vector<std::string> empty;
446     return empty;
447   }
448   return session_->GetOutputTensorNames();
449 }
450 
GetOutputByTensorName(const std::string & name)451 MSTensor ModelImpl::GetOutputByTensorName(const std::string &name) {
452   if (session_ == nullptr) {
453     MS_LOG(ERROR) << "Session is null.";
454     return MSTensor(nullptr);
455   }
456   auto res = session_->GetOutputByTensorName(name);
457   if (res == nullptr) {
458     MS_LOG(ERROR) << "Model does not contains tensor " << name << " .";
459     return MSTensor(nullptr);
460   }
461   auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(res));
462   if (impl == nullptr || impl->lite_tensor() == nullptr) {
463     MS_LOG(ERROR) << "Create tensor failed.";
464     return MSTensor(nullptr);
465   }
466 
467   return MSTensor(impl);
468 }
469 
GetOutputsByNodeName(const std::string & name)470 std::vector<MSTensor> ModelImpl::GetOutputsByNodeName(const std::string &name) {
471   std::vector<MSTensor> empty;
472   if (session_ == nullptr) {
473     MS_LOG(ERROR) << "Session is null.";
474     return empty;
475   }
476   std::vector<MSTensor> res;
477   auto outputs = session_->GetOutputsByNodeName(name);
478   if (outputs.empty()) {
479     MS_LOG(ERROR) << "The outputs of model is null.";
480     return empty;
481   }
482   res.resize(outputs.size());
483   for (size_t i = 0; i < outputs.size(); i++) {
484     auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(outputs[i]));
485     if (impl == nullptr || impl->lite_tensor() == nullptr) {
486       MS_LOG(ERROR) << "Create tensor failed.";
487       return empty;
488     }
489     auto tensor = MSTensor(impl);
490     if (tensor == nullptr) {
491       MS_LOG(ERROR) << "Create tensor failed.";
492       return empty;
493     }
494     res[i] = tensor;
495   }
496   return res;
497 }
498 
Resize(const std::vector<MSTensor> & inputs,const std::vector<std::vector<int64_t>> & dims)499 Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
500   if (session_ == nullptr) {
501     MS_LOG(ERROR) << "Session is null.";
502     return kLiteNullptr;
503   }
504   if (inputs.empty()) {
505     MS_LOG(ERROR) << "Inputs is null.";
506     return kLiteInputParamInvalid;
507   }
508   if (dims.empty()) {
509     MS_LOG(ERROR) << "Dims is null.";
510     return kLiteInputParamInvalid;
511   }
512   if (inputs.size() != dims.size()) {
513     MS_LOG(ERROR) << "The size of inputs does not match the size of dims.";
514     return kLiteInputParamInvalid;
515   }
516   auto model_inputs = session_->GetInputs();
517   if (model_inputs.empty()) {
518     MS_LOG(ERROR) << "The inputs of model is null.";
519     return kLiteParamInvalid;
520   }
521   if (inputs.size() != model_inputs.size()) {
522     MS_LOG(ERROR) << "The size of inputs is incorrect.";
523     return kLiteInputParamInvalid;
524   }
525   std::vector<tensor::MSTensor *> inner_input;
526   inner_input.resize(inputs.size());
527   std::vector<std::vector<int32_t>> truncated_shape;
528   truncated_shape.resize(inputs.size());
529   for (size_t i = 0; i < inputs.size(); i++) {
530     auto input = inputs[i];
531     if (input.impl_ == nullptr || input.impl_->lite_tensor() == nullptr) {
532       MS_LOG(ERROR) << "Input tensor " << input.Name() << " is null.";
533       return kLiteInputTensorError;
534     }
535     inner_input[i] = input.impl_->lite_tensor();
536     std::vector<int32_t> shape = TruncateShape(dims[i], inner_input[i]->data_type(), inner_input[i]->Size(), false);
537     if (shape.empty() && !(dims[i].empty())) {
538       MS_LOG(ERROR) << "Input dims[" << i << "] is invalid.";
539       return kLiteParamInvalid;
540     }
541     truncated_shape[i] = shape;
542   }
543   auto ret = session_->Resize(inner_input, truncated_shape);
544   return static_cast<StatusCode>(ret);
545 }
546 
CreateLiteSession(lite::InnerContext * context)547 lite::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context) {
548   auto session = new (std::nothrow) lite::LiteSession();
549   if (session == nullptr) {
550     MS_LOG(ERROR) << "create session failed";
551     delete context;
552     return nullptr;
553   }
554 
555   session->InitExecutionConfig(&execution_plan_);
556 
557   auto ret = session->Init(context);
558   if (ret != mindspore::lite::RET_OK) {
559     MS_LOG(ERROR) << "init session failed";
560     delete session;
561     return nullptr;
562   }
563   return session;
564 }
565 }  // namespace mindspore
566