1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <getopt.h>
18 #include <string>
19 #include <iostream>
20 #include <fstream>
21 #include "src/utils.h"
22 #include "include/api/model.h"
23 #include "include/api/context.h"
24 #include "include/api/graph.h"
25 #include "include/api/serialization.h"
26
/// Writes the one-line command-line synopsis to standard output.
static void Usage() {
  static const char kSynopsis[] = "Usage: infer -f <.ms model file>";
  std::cout << kSynopsis << std::endl;
}
28
/// Scans the command line with getopt and returns the argument of the `-f`
/// option. If `-f` appears several times the last occurrence wins; if it never
/// appears an empty string is returned. Any other option is skipped.
static std::string ReadArgs(int argc, char *argv[]) {
  std::string model_path;
  for (int flag = getopt(argc, argv, "f:"); flag != -1; flag = getopt(argc, argv, "f:")) {
    if (flag == 'f') {
      model_path = std::string(optarg);
    }
  }
  return model_path;
}
43
main(int argc,char ** argv)44 int main(int argc, char **argv) {
45 std::string infer_model_fn = ReadArgs(argc, argv);
46 if (infer_model_fn.size() == 0) {
47 Usage();
48 return -1;
49 }
50
51 auto context = std::make_shared<mindspore::Context>();
52 auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
53 cpu_context->SetEnableFP16(false);
54 context->MutableDeviceInfo().push_back(cpu_context);
55
56 mindspore::Graph graph;
57 auto status = mindspore::Serialization::Load(infer_model_fn, mindspore::kMindIR, &graph);
58 if (status != mindspore::kSuccess) {
59 std::cout << "Error " << status << " during serialization of graph " << infer_model_fn;
60 MS_ASSERT(status != mindspore::kSuccess);
61 }
62
63 mindspore::Model model;
64 status = model.Build(mindspore::GraphCell(graph), context);
65 if (status != mindspore::kSuccess) {
66 std::cout << "Error " << status << " during build of model " << infer_model_fn;
67 MS_ASSERT(status != mindspore::kSuccess);
68 }
69
70 auto inputs = model.GetInputs();
71 MS_ASSERT(inputs.size() >= 1);
72
73 int index = 0;
74 std::cout << "There are " << inputs.size() << " input tensors with sizes: " << std::endl;
75 for (auto tensor : inputs) {
76 std::cout << "tensor " << index++ << ": shape is [";
77 for (auto dim : tensor.Shape()) {
78 std::cout << dim << " ";
79 }
80 std::cout << "]" << std::endl;
81 }
82
83 mindspore::MSTensor *input_tensor = inputs.at(0).Clone();
84 auto *input_data = reinterpret_cast<float *>(input_tensor->MutableData());
85 std::ifstream in;
86 in.open("dataset/batch_of32.dat", std::ios::in | std::ios::binary);
87 if (in.fail()) {
88 std::cout << "error loading dataset/batch_of32.dat file reading" << std::endl;
89 MS_ASSERT(!in.fail());
90 }
91 in.read(reinterpret_cast<char *>(input_data), inputs.at(0).ElementNum() * sizeof(float));
92 in.close();
93
94 std::vector<mindspore::MSTensor> outputs;
95 status = model.Predict({*input_tensor}, &outputs);
96 if (status != mindspore::kSuccess) {
97 std::cout << "Error " << status << " during running predict of model " << infer_model_fn;
98 MS_ASSERT(status != mindspore::kSuccess);
99 }
100
101 index = 0;
102 std::cout << "There are " << outputs.size() << " output tensors with sizes: " << std::endl;
103 for (auto tensor : outputs) {
104 std::cout << "tensor " << index++ << ": shape is [";
105 for (auto dim : tensor.Shape()) {
106 std::cout << dim << " ";
107 }
108 std::cout << "]" << std::endl;
109 }
110
111 if (outputs.size() > 0) {
112 std::cout << "The predicted classes are:" << std::endl;
113 auto predictions = reinterpret_cast<float *>(outputs.at(0).MutableData());
114 int i = 0;
115 for (int b = 0; b < outputs.at(0).Shape().at(0); b++) {
116 int max_c = 0;
117 float max_p = predictions[i];
118 for (int c = 0; c < outputs.at(0).Shape().at(1); c++, i++) {
119 if (predictions[i] > max_p) {
120 max_c = c;
121 max_p = predictions[i];
122 }
123 }
124 std::cout << max_c << ", ";
125 }
126 std::cout << std::endl;
127 }
128 return 0;
129 }
130