/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "include/api/model.h"
#include <algorithm>
#include <map>
#include "include/api/context.h"
#include "cxx_api/model/model_impl.h"
#include "cxx_api/factory.h"
#include "utils/utils.h"

namespace mindspore {
namespace {
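// Maps a DeviceType enum value to the backend name used as the ModelImpl factory key.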
std::string GetDeviceTypeString(enum DeviceType type) {
  static const std::map<enum DeviceType, std::string> kDeviceTypeStrs = {
    {kCPU, "CPU"}, {kGPU, "GPU"}, {kKirinNPU, "KirinNPU"}, {kAscend910, "Ascend910"}, {kAscend310, "Ascend310"},
  };
  auto iter = kDeviceTypeStrs.find(type);
  if (iter != kDeviceTypeStrs.end()) {
    return iter->second;
  }

  return "InvalidDeviceType" + std::to_string(static_cast<int>(type));
}
}  // namespace
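// Builds the model from an in-memory graph. The context must carry exactly one
// device descriptor; its device type selects which ModelImpl backend the factory creates.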
Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_context,
                    const std::shared_ptr<TrainCfg> &) {
  if (graph_cell.GetGraph() == nullptr) {
    MS_LOG(ERROR) << "Invalid graph input.";
    return kMCInvalidInput;
  }

  if (model_context == nullptr) {
    MS_LOG(ERROR) << "Invalid model context.";
    return kMCInvalidInput;
  }
  auto &device_info = model_context->MutableDeviceInfo();
  if (device_info.size() != 1) {
    MS_LOG(ERROR) << "Invalid model context, only single device info is supported.";
    return kMCInvalidInput;
  }

  std::string device_target = GetDeviceTypeString(device_info[0]->GetDeviceType());
  impl_ = Factory<ModelImpl>::Instance().Create(device_target);
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Create session type " << device_target << " failed";
    return kMEFailed;
  }

  g_device_target = device_target;

  impl_->SetGraph(std::make_shared<Graph>(*graph_cell.GetGraph()));
  impl_->SetContext(model_context);

  return impl_->Build();
}

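// Building from a serialized buffer or from a model file path is not implemented in
// this backend; both overloads fail unconditionally.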
Status Model::Build(const void *, size_t, ModelType, const std::shared_ptr<Context> &, const Key &,
                    const std::string &) {
  MS_LOG(ERROR) << "Unsupported Feature.";
  return kMCFailed;
}

Status Model::Build(const std::string &, ModelType, const std::shared_ptr<Context> &, const Key &,
                    const std::string &) {
  MS_LOG(ERROR) << "Unsupported Feature.";
  return kMCFailed;
}

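// Resizes the model's inputs to the given dimensions. Requires a successful Build() first.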
Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Failed because this model has not been built.";
    return kMCFailed;
  }
  return impl_->Resize(inputs, dims);
}

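// Runs inference on the built model. The before/after kernel callbacks are accepted
// for API compatibility but are not forwarded to the implementation here.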
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                      const MSKernelCallBack &before, const MSKernelCallBack &after) {
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Failed because this model has not been built.";
    return kMCFailed;
  }
  return impl_->Predict(inputs, outputs);
}

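// Runs the model's bound data-preprocessing pipeline and then inference in one call;
// the callbacks are likewise not forwarded.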
Status Model::PredictWithPreprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                                    const MSKernelCallBack &before, const MSKernelCallBack &after) {
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Failed because this model has not been built.";
    return kMCFailed;
  }
  return impl_->PredictWithPreprocess(inputs, outputs);
}

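// Runs only the data-preprocessing pipeline, without inference.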
Status Model::Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Failed because this model has not been built.";
    return kMCFailed;
  }
  return impl_->Preprocess(inputs, outputs);
}

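// Reports whether the loaded model carries a preprocessing pipeline.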
bool Model::HasPreprocess() {
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Failed because this model has not been built.";
    return false;
  }
  return impl_->HasPreprocess();
}

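// Returns the model's input tensors, or an empty vector if the model is not built.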
std::vector<MSTensor> Model::GetInputs() {
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Failed because this model has not been built.";
    return {};
  }
  return impl_->GetInputs();
}

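// Returns the model's output tensors, or an empty vector if the model is not built.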
std::vector<MSTensor> Model::GetOutputs() {
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Failed because this model has not been built.";
    return {};
  }
  return impl_->GetOutputs();
}

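// Looks up an input tensor by name; returns a null MSTensor when no input matches.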
MSTensor Model::GetInputByTensorName(const std::vector<char> &tensor_name) {
  std::string tensor_name_str = CharToString(tensor_name);
  auto inputs = GetInputs();
  for (auto in : inputs) {
    if (in.Name() == tensor_name_str) {
      return in;
    }
  }

  return MSTensor(nullptr);
}

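// Collects the names of all output tensors.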
std::vector<std::vector<char>> Model::GetOutputTensorNamesChar() {
  std::vector<std::vector<char>> ret;
  auto outputs = GetOutputs();
  std::transform(outputs.begin(), outputs.end(), std::back_inserter(ret),
                 [](const MSTensor &item) -> std::vector<char> { return StringToChar(item.Name()); });
  return ret;
}

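// Looks up an output tensor by name; returns a null MSTensor when no output matches.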
MSTensor Model::GetOutputByTensorName(const std::vector<char> &tensor_name) {
  std::string tensor_name_str = CharToString(tensor_name);
  auto outputs = GetOutputs();
  for (auto out : outputs) {
    if (out.Name() == tensor_name_str) {
      return out;
    }
  }

  return MSTensor(nullptr);
}

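// Resolves outputs by node name. In this backend the node name is treated as a tensor
// name, so the result is a single-element vector.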
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::vector<char> &node_name) {
  return std::vector<MSTensor>{GetOutputByTensorName(node_name)};
}

Model::Model() : impl_(nullptr) {}
Model::~Model() {}

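// Checks whether the given device type and model format are supported by instantiating
// the matching backend and querying it.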
bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) {
  std::string device_type_str = GetDeviceTypeString(device_type);
  if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type_str)) {
    return false;
  }

  auto check_model = Factory<ModelImpl>::Instance().Create(device_type_str);
  if (check_model == nullptr) {
    return false;
  }

  return check_model->CheckModelSupport(model_type);
}
}  // namespace mindspore
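
// Minimal usage sketch (illustrative only, not part of this file; it assumes a MindIR
// file "model.mindir", an Ascend 310 target, and the Serialization::Load and
// Ascend310DeviceInfo helpers from include/api, which may differ in other versions):
//
//   auto context = std::make_shared<mindspore::Context>();
//   context->MutableDeviceInfo().push_back(std::make_shared<mindspore::Ascend310DeviceInfo>());
//   mindspore::Graph graph;
//   (void)mindspore::Serialization::Load("model.mindir", mindspore::kMindIR, &graph);
//   mindspore::Model model;
//   if (model.Build(mindspore::GraphCell(graph), context, nullptr) == mindspore::kSuccess) {
//     std::vector<mindspore::MSTensor> outputs;
//     (void)model.Predict(model.GetInputs(), &outputs, nullptr, nullptr);
//   }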