1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "include/api/model.h"
17 #include "include/api/context.h"
18 #include "cxx_api/model/model_impl.h"
19 #include "cxx_api/factory.h"
20 #include "utils/utils.h"
21
22 namespace mindspore {
23 namespace {
GetDeviceTypeString(enum DeviceType type)24 std::string GetDeviceTypeString(enum DeviceType type) {
25 static const std::map<enum DeviceType, std::string> kDeviceTypeStrs = {
26 {kCPU, "CPU"}, {kGPU, "GPU"}, {kKirinNPU, "KirinGPU"}, {kAscend910, "Ascend910"}, {kAscend310, "Ascend310"},
27 };
28 auto iter = kDeviceTypeStrs.find(type);
29 if (iter != kDeviceTypeStrs.end()) {
30 return iter->second;
31 }
32
33 return "InvalidDeviceType" + std::to_string(static_cast<int>(type));
34 }
35 } // namespace
Build(GraphCell graph_cell,const std::shared_ptr<Context> & model_context,const std::shared_ptr<TrainCfg> &)36 Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_context,
37 const std::shared_ptr<TrainCfg> &) {
38 if (graph_cell.GetGraph() == nullptr) {
39 MS_LOG(ERROR) << "Invalid graph input.";
40 return kMCInvalidInput;
41 }
42
43 if (model_context == nullptr) {
44 MS_LOG(ERROR) << "Invalid model context.";
45 return kMCInvalidInput;
46 }
47 auto &device_info = model_context->MutableDeviceInfo();
48 if (device_info.size() != 1) {
49 MS_LOG(ERROR) << "Invalid model context, only single device info is supported.";
50 return kMCInvalidInput;
51 }
52
53 std::string device_target = GetDeviceTypeString(device_info[0]->GetDeviceType());
54 impl_ = Factory<ModelImpl>::Instance().Create(device_target);
55 if (impl_ == nullptr) {
56 MS_LOG(ERROR) << "Create session type " << device_target << " failed";
57 return kMEFailed;
58 }
59
60 g_device_target = device_target;
61
62 impl_->SetGraph(std::make_shared<Graph>(*graph_cell.GetGraph()));
63 impl_->SetContext(model_context);
64
65 return impl_->Build();
66 }
67
Build(const void *,size_t,ModelType,const std::shared_ptr<Context> &,const Key &,const std::string &)68 Status Model::Build(const void *, size_t, ModelType, const std::shared_ptr<Context> &, const Key &,
69 const std::string &) {
70 MS_LOG(ERROR) << "Unsupported Feature.";
71 return kMCFailed;
72 }
73
Build(const std::string &,ModelType,const std::shared_ptr<Context> &,const Key &,const std::string &)74 Status Model::Build(const std::string &, ModelType, const std::shared_ptr<Context> &, const Key &,
75 const std::string &) {
76 MS_LOG(ERROR) << "Unsupported Feature.";
77 return kMCFailed;
78 }
79
Resize(const std::vector<MSTensor> & inputs,const std::vector<std::vector<int64_t>> & dims)80 Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
81 if (impl_ == nullptr) {
82 MS_LOG(ERROR) << "Failed because this model has not been built.";
83 return kMCFailed;
84 }
85 return impl_->Resize(inputs, dims);
86 }
87
Predict(const std::vector<MSTensor> & inputs,std::vector<MSTensor> * outputs,const MSKernelCallBack & before,const MSKernelCallBack & after)88 Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
89 const MSKernelCallBack &before, const MSKernelCallBack &after) {
90 if (impl_ == nullptr) {
91 MS_LOG(ERROR) << "Failed because this model has not been built.";
92 return kMCFailed;
93 }
94 return impl_->Predict(inputs, outputs);
95 }
96
PredictWithPreprocess(const std::vector<MSTensor> & inputs,std::vector<MSTensor> * outputs,const MSKernelCallBack & before,const MSKernelCallBack & after)97 Status Model::PredictWithPreprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
98 const MSKernelCallBack &before, const MSKernelCallBack &after) {
99 if (impl_ == nullptr) {
100 MS_LOG(ERROR) << "Failed because this model has not been built.";
101 return kMCFailed;
102 }
103 return impl_->PredictWithPreprocess(inputs, outputs);
104 }
105
Preprocess(const std::vector<MSTensor> & inputs,std::vector<MSTensor> * outputs)106 Status Model::Preprocess(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
107 if (impl_ == nullptr) {
108 MS_LOG(ERROR) << "Failed because this model has not been built.";
109 return kMCFailed;
110 }
111 return impl_->Preprocess(inputs, outputs);
112 }
113
HasPreprocess()114 bool Model::HasPreprocess() {
115 if (impl_ == nullptr) {
116 MS_LOG(ERROR) << "Failed because this model has not been built.";
117 return false;
118 }
119 return impl_->HasPreprocess();
120 }
121
GetInputs()122 std::vector<MSTensor> Model::GetInputs() {
123 if (impl_ == nullptr) {
124 MS_LOG(ERROR) << "Failed because this model has not been built.";
125 return {};
126 }
127 return impl_->GetInputs();
128 }
129
GetOutputs()130 std::vector<MSTensor> Model::GetOutputs() {
131 if (impl_ == nullptr) {
132 MS_LOG(ERROR) << "Failed because this model has not been built.";
133 return {};
134 }
135 return impl_->GetOutputs();
136 }
137
GetInputByTensorName(const std::vector<char> & tensor_name)138 MSTensor Model::GetInputByTensorName(const std::vector<char> &tensor_name) {
139 std::string tensor_name_str = CharToString(tensor_name);
140 auto inputs = GetInputs();
141 for (auto in : inputs) {
142 if (in.Name() == tensor_name_str) {
143 return in;
144 }
145 }
146
147 return MSTensor(nullptr);
148 }
149
GetOutputTensorNamesChar()150 std::vector<std::vector<char>> Model::GetOutputTensorNamesChar() {
151 std::vector<std::vector<char>> ret;
152 auto outputs = GetOutputs();
153 std::transform(outputs.begin(), outputs.end(), std::back_inserter(ret),
154 [](const MSTensor &item) -> std::vector<char> { return StringToChar(item.Name()); });
155 return ret;
156 }
157
GetOutputByTensorName(const std::vector<char> & tensor_name)158 MSTensor Model::GetOutputByTensorName(const std::vector<char> &tensor_name) {
159 std::string tensor_name_str = CharToString(tensor_name);
160 auto outputs = GetOutputs();
161 for (auto out : outputs) {
162 if (out.Name() == tensor_name_str) {
163 return out;
164 }
165 }
166
167 return MSTensor(nullptr);
168 }
169
GetOutputsByNodeName(const std::vector<char> & node_name)170 std::vector<MSTensor> Model::GetOutputsByNodeName(const std::vector<char> &node_name) {
171 return std::vector<MSTensor>{GetOutputByTensorName(node_name)};
172 }
173
Model()174 Model::Model() : impl_(nullptr) {}
~Model()175 Model::~Model() {}
176
CheckModelSupport(enum DeviceType device_type,ModelType model_type)177 bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type) {
178 std::string device_type_str = GetDeviceTypeString(device_type);
179 if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type_str)) {
180 return false;
181 }
182
183 auto check_model = Factory<ModelImpl>::Instance().Create(device_type_str);
184 if (check_model == nullptr) {
185 return false;
186 }
187
188 return check_model->CheckModelSupport(model_type);
189 }
190 } // namespace mindspore
191