• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/api/model.h"
18 #include <mutex>
19 #include "include/api/types.h"
20 #include "include/api/context.h"
21 #include "include/api/callback/callback.h"
22 #include "include/api/dual_abi_helper.h"
23 #include "src/cxx_api/model/model_impl.h"
24 #include "src/cxx_api/callback/callback_impl.h"
25 #include "src/cxx_api/callback/callback_adapter.h"
26 #include "src/common/log_adapter.h"
27 
28 namespace mindspore {
// Mutex taken by Model::Build and Model::LoadConfig around the creation of Model::impl_.
// It is one namespace-scope object, so every Model instance in the process shares it.
std::mutex g_impl_init_lock;
30 
Build(const void * model_data,size_t data_size,ModelType model_type,const std::shared_ptr<Context> & model_context,const Key & dec_key,const std::string & dec_mode)31 Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
32                     const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
33   if (impl_ == nullptr) {
34     std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
35     impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
36     if (impl_ == nullptr) {
37       MS_LOG(ERROR) << "Model implement is null.";
38       return kLiteFileError;
39     }
40   }
41 
42   Status ret = impl_->Build(model_data, data_size, model_type, model_context);
43   if (ret != kSuccess) {
44     return ret;
45   }
46   return kSuccess;
47 }
48 
Build(const std::string & model_path,ModelType model_type,const std::shared_ptr<Context> & model_context,const Key & dec_key,const std::string & dec_mode)49 Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
50                     const Key &dec_key, const std::string &dec_mode) {
51   if (impl_ == nullptr) {
52     std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
53     impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
54     if (impl_ == nullptr) {
55       MS_LOG(ERROR) << "Model implement is null.";
56       return kLiteFileError;
57     }
58   }
59 
60   Status ret = impl_->Build(model_path, model_type, model_context);
61   if (ret != kSuccess) {
62     return ret;
63   }
64   return kSuccess;
65 }
66 
Build(GraphCell graph,const std::shared_ptr<Context> & model_context,const std::shared_ptr<TrainCfg> & train_cfg)67 Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_context,
68                     const std::shared_ptr<TrainCfg> &train_cfg) {
69   std::stringstream err_msg;
70   if (impl_ == nullptr) {
71     std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
72     impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
73     if (impl_ == nullptr) {
74       MS_LOG(ERROR) << "Model implement is null.";
75       return kLiteFileError;
76     }
77   }
78 
79   if (graph.GetGraph() == nullptr) {
80     err_msg << "Invalid null graph.";
81     MS_LOG(ERROR) << err_msg.str();
82     return Status(kLiteNullptr, err_msg.str());
83   }
84   if (model_context == nullptr) {
85     err_msg << "Invalid null context.";
86     MS_LOG(ERROR) << err_msg.str();
87     return Status(kLiteNullptr, err_msg.str());
88   }
89   impl_->SetContext(model_context);
90   impl_->SetGraph(graph.GetGraph());
91   impl_->SetConfig(train_cfg);
92   return impl_->Build();
93 }
94 
Resize(const std::vector<MSTensor> & inputs,const std::vector<std::vector<int64_t>> & dims)95 Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
96   if (impl_ == nullptr) {
97     MS_LOG(ERROR) << "Model implement is null.";
98     return kLiteNullptr;
99   }
100   return impl_->Resize(inputs, dims);
101 }
102 
Predict(const std::vector<MSTensor> & inputs,std::vector<MSTensor> * outputs,const MSKernelCallBack & before,const MSKernelCallBack & after)103 Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
104                       const MSKernelCallBack &before, const MSKernelCallBack &after) {
105   if (impl_ == nullptr) {
106     MS_LOG(ERROR) << "Model implement is null.";
107     return kLiteNullptr;
108   }
109   return impl_->Predict(inputs, outputs, before, after);
110 }
111 
PredictWithPreprocess(const std::vector<MSTensor> & inputs,std::vector<MSTensor> * outputs,const MSKernelCallBack & before,const MSKernelCallBack & after)112 Status Model::PredictWithPreprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
113                                     const MSKernelCallBack &before, const MSKernelCallBack &after) {
114   MS_LOG(ERROR) << "Unsupported Feature.";
115   return kLiteNotSupport;
116 }
117 
Preprocess(const std::vector<MSTensor> & inputs,std::vector<MSTensor> * outputs)118 Status Model::Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
119   MS_LOG(ERROR) << "Unsupported Feature.";
120   return kLiteNotSupport;
121 }
122 
HasPreprocess()123 bool Model::HasPreprocess() {
124   MS_LOG(ERROR) << "Unsupported Feature.";
125   return false;
126 }
127 
Model()128 Model::Model() : impl_(nullptr) {}
129 
~Model()130 Model::~Model() {}
131 
CheckModelSupport(enum DeviceType device_type,ModelType model_type)132 bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) {
133   MS_LOG(ERROR) << "Unsupported feature.";
134   return false;
135 }
136 
GetInputs()137 std::vector<MSTensor> Model::GetInputs() {
138   std::vector<MSTensor> empty;
139   if (impl_ == nullptr) {
140     MS_LOG(ERROR) << "Model implement is null.";
141     return empty;
142   }
143   return impl_->GetInputs();
144 }
145 
GetOutputs()146 std::vector<MSTensor> Model::GetOutputs() {
147   std::vector<MSTensor> empty;
148   if (impl_ == nullptr) {
149     MS_LOG(ERROR) << "Model implement is null.";
150     return empty;
151   }
152   return impl_->GetOutputs();
153 }
154 
GetInputByTensorName(const std::vector<char> & name)155 MSTensor Model::GetInputByTensorName(const std::vector<char> &name) {
156   if (impl_ == nullptr) {
157     MS_LOG(ERROR) << "Model implement is null.";
158     return MSTensor(nullptr);
159   }
160   return impl_->GetInputByTensorName(CharToString(name));
161 }
162 
GetOutputTensorNamesChar()163 std::vector<std::vector<char>> Model::GetOutputTensorNamesChar() {
164   if (impl_ == nullptr) {
165     MS_LOG(ERROR) << "Model implement is null.";
166     std::vector<std::vector<char>> empty;
167     return empty;
168   }
169   return VectorStringToChar(impl_->GetOutputTensorNames());
170 }
171 
GetOutputByTensorName(const std::vector<char> & name)172 MSTensor Model::GetOutputByTensorName(const std::vector<char> &name) {
173   if (impl_ == nullptr) {
174     MS_LOG(ERROR) << "Model implement is null.";
175     return MSTensor(nullptr);
176   }
177   return impl_->GetOutputByTensorName(CharToString(name));
178 }
179 
GetOutputsByNodeName(const std::vector<char> & node_name)180 std::vector<MSTensor> Model::GetOutputsByNodeName(const std::vector<char> &node_name) {
181   if (impl_ == nullptr) {
182     MS_LOG(ERROR) << "Model implement is null.";
183     std::vector<MSTensor> empty;
184     return empty;
185   }
186   return impl_->GetOutputsByNodeName(CharToString(node_name));
187 }
188 
LoadConfig(const std::string & config_path)189 Status Model::LoadConfig(const std::string &config_path) {
190   std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
191   if (impl_ != nullptr) {
192     MS_LOG(ERROR) << "impl_ illegal in LoadConfig.";
193     return Status(kLiteFileError, "Illegal operation.");
194   }
195 
196   impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
197   if (impl_ == nullptr) {
198     MS_LOG(ERROR) << "Model implement is null.";
199     return Status(kLiteFileError, "Fail to load config file.");
200   }
201 
202   auto ret = impl_->LoadConfig(config_path);
203   if (ret != kSuccess) {
204     MS_LOG(ERROR) << "impl_ LoadConfig failed,";
205     return Status(kLiteFileError, "Invalid config file.");
206   }
207   return kSuccess;
208 }
209 
SetTrainMode(bool train)210 Status Model::SetTrainMode(bool train) {
211   if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
212     MS_LOG(ERROR) << "Model is null.";
213     return kLiteUninitializedObj;
214   }
215   auto ret = (train) ? impl_->session_->Train() : impl_->session_->Eval();
216   return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
217 }
218 
GetTrainMode() const219 bool Model::GetTrainMode() const { return ((impl_ != nullptr) && (impl_->session_) && (impl_->session_->IsTrain())); }
220 
GetGradients() const221 std::vector<MSTensor> Model::GetGradients() const {
222   std::vector<MSTensor> empty;
223   if (impl_ == nullptr) {
224     MS_LOG(ERROR) << "Model implement is null.";
225     return empty;
226   }
227   return impl_->GetGradients();
228 }
229 
ApplyGradients(const std::vector<MSTensor> & gradients)230 Status Model::ApplyGradients(const std::vector<MSTensor> &gradients) {
231   if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
232     MS_LOG(ERROR) << "Model is null.";
233     return kLiteUninitializedObj;
234   }
235   return impl_->ApplyGradients(gradients);
236 }
237 
GetOptimizerParams() const238 std::vector<MSTensor> Model::GetOptimizerParams() const {
239   std::vector<MSTensor> empty;
240   if (impl_ == nullptr) {
241     MS_LOG(ERROR) << "Model implement is null.";
242     return empty;
243   }
244   auto res = impl_->GetOptimizerParams();
245   return res;
246 }
247 
SetOptimizerParams(const std::vector<MSTensor> & params)248 Status Model::SetOptimizerParams(const std::vector<MSTensor> &params) {
249   if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
250     MS_LOG(ERROR) << "Model is null.";
251     return kLiteUninitializedObj;
252   }
253   return impl_->SetOptimizerParams(params);
254 }
255 
InitMetrics(std::vector<Metrics * > metrics)256 Status Model::InitMetrics(std::vector<Metrics *> metrics) {
257   if (impl_ == nullptr) {
258     MS_LOG(ERROR) << "Model implement is null.";
259     return kLiteUninitializedObj;
260   }
261   return impl_->InitMetrics(metrics);
262 }
263 
GetMetrics()264 std::vector<Metrics *> Model::GetMetrics() {
265   std::vector<Metrics *> empty;
266   if (impl_ == nullptr) {
267     MS_LOG(ERROR) << "Model implement is null.";
268     return empty;
269   }
270   return impl_->GetMetrics();
271 }
272 
273 }  // namespace mindspore
274