/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <getopt.h>

#include <cstdint>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "src/utils.h"
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/graph.h"
#include "include/api/serialization.h"
Usage()27 static void Usage() { std::cout << "Usage: infer -f <.ms model file>" << std::endl; }
// Parse command-line arguments and return the model file name given with
// -f, or an empty string if the flag was absent. Unknown options are ignored.
static std::string ReadArgs(int argc, char *argv[]) {
  std::string model_path;
  for (int flag = getopt(argc, argv, "f:"); flag != -1; flag = getopt(argc, argv, "f:")) {
    if (flag == 'f') {
      model_path = optarg;
    }
    // Any other option (including getopt's '?' for unrecognized flags)
    // is deliberately skipped.
  }
  return model_path;
}
main(int argc,char ** argv)44 int main(int argc, char **argv) {
45   std::string infer_model_fn = ReadArgs(argc, argv);
46   if (infer_model_fn.size() == 0) {
47     Usage();
48     return -1;
49   }
50 
51   auto context = std::make_shared<mindspore::Context>();
52   auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
53   cpu_context->SetEnableFP16(false);
54   context->MutableDeviceInfo().push_back(cpu_context);
55 
56   mindspore::Graph graph;
57   auto status = mindspore::Serialization::Load(infer_model_fn, mindspore::kMindIR, &graph);
58   if (status != mindspore::kSuccess) {
59     std::cout << "Error " << status << " during serialization of graph " << infer_model_fn;
60     MS_ASSERT(status != mindspore::kSuccess);
61   }
62 
63   mindspore::Model model;
64   status = model.Build(mindspore::GraphCell(graph), context);
65   if (status != mindspore::kSuccess) {
66     std::cout << "Error " << status << " during build of model " << infer_model_fn;
67     MS_ASSERT(status != mindspore::kSuccess);
68   }
69 
70   auto inputs = model.GetInputs();
71   MS_ASSERT(inputs.size() >= 1);
72 
73   int index = 0;
74   std::cout << "There are " << inputs.size() << " input tensors with sizes: " << std::endl;
75   for (auto tensor : inputs) {
76     std::cout << "tensor " << index++ << ": shape is [";
77     for (auto dim : tensor.Shape()) {
78       std::cout << dim << " ";
79     }
80     std::cout << "]" << std::endl;
81   }
82 
83   mindspore::MSTensor *input_tensor = inputs.at(0).Clone();
84   auto *input_data = reinterpret_cast<float *>(input_tensor->MutableData());
85   std::ifstream in;
86   in.open("dataset/batch_of32.dat", std::ios::in | std::ios::binary);
87   if (in.fail()) {
88     std::cout << "error loading dataset/batch_of32.dat file reading" << std::endl;
89     MS_ASSERT(!in.fail());
90   }
91   in.read(reinterpret_cast<char *>(input_data), inputs.at(0).ElementNum() * sizeof(float));
92   in.close();
93 
94   std::vector<mindspore::MSTensor> outputs;
95   status = model.Predict({*input_tensor}, &outputs);
96   if (status != mindspore::kSuccess) {
97     std::cout << "Error " << status << " during running predict of model " << infer_model_fn;
98     MS_ASSERT(status != mindspore::kSuccess);
99   }
100 
101   index = 0;
102   std::cout << "There are " << outputs.size() << " output tensors with sizes: " << std::endl;
103   for (auto tensor : outputs) {
104     std::cout << "tensor " << index++ << ": shape is [";
105     for (auto dim : tensor.Shape()) {
106       std::cout << dim << " ";
107     }
108     std::cout << "]" << std::endl;
109   }
110 
111   if (outputs.size() > 0) {
112     std::cout << "The predicted classes are:" << std::endl;
113     auto predictions = reinterpret_cast<float *>(outputs.at(0).MutableData());
114     int i = 0;
115     for (int b = 0; b < outputs.at(0).Shape().at(0); b++) {
116       int max_c = 0;
117       float max_p = predictions[i];
118       for (int c = 0; c < outputs.at(0).Shape().at(1); c++, i++) {
119         if (predictions[i] > max_p) {
120           max_c = c;
121           max_p = predictions[i];
122         }
123       }
124       std::cout << max_c << ", ";
125     }
126     std::cout << std::endl;
127   }
128   return 0;
129 }