//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ImageTensorGenerator.hpp"
#include "../InferenceTestImage.hpp"
#include <armnn/Logging.hpp>
#include <armnn/TypesUtils.hpp>
#include <armnnUtils/Filesystem.hpp>

#include <cxxopts/cxxopts.hpp>

#include <algorithm>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

namespace
{

// Parses the command line to extract
// * the input image file  -i the input image file path (must exist)
// * the model format      -f the format of the model that will consume the tensor (accepted value: tflite)
// * the output type       -z the data type of the generated tensor (optional - default value is float)
// * the new width/height  --new-width/--new-height optional dimensions to resize the image to
// * the layout            -l the data layout the output is generated with (optional - default value is NHWC)
// * the output file       -o the output raw tensor file path (must not already exist)
// An example invocation is shown below.
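//
// A typical invocation looks like the following (the paths are illustrative only):
//
//     ImageTensorGenerator -i /path/to/image.jpg -o /path/to/image.raw -f tflite -z float
//                          --new-width 224 --new-height 224 -l NHWC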
class CommandLineProcessor
{
public:
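    // Checks that the mandatory options were supplied and that all supplied values are valid.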
    bool ParseOptions(cxxopts::ParseResult& result)
    {
        // infile is mandatory
        if (result.count("infile"))
        {
            if (!ValidateInputFile(result["infile"].as<std::string>()))
            {
                return false;
            }
        }
        else
        {
            std::cerr << "-i/--infile parameter is mandatory." << std::endl;
            return false;
        }

        // model-format is mandatory
        if (!result.count("model-format"))
        {
            std::cerr << "-f/--model-format parameter is mandatory." << std::endl;
            return false;
        }

        // outfile is mandatory
        if (result.count("outfile"))
        {
            if (!ValidateOutputFile(result["outfile"].as<std::string>()))
            {
                return false;
            }
        }
        else
        {
            std::cerr << "-o/--outfile parameter is mandatory." << std::endl;
            return false;
        }

        if (result.count("layout"))
        {
            if (!ValidateLayout(result["layout"].as<std::string>()))
            {
                return false;
            }
        }

        return true;
    }

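    // The input image must be an existing regular file, not a directory.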
    bool ValidateInputFile(const std::string& inputFileName)
    {
        if (inputFileName.empty())
        {
            std::cerr << "No input file name specified" << std::endl;
            return false;
        }

        if (!fs::exists(inputFileName))
        {
            std::cerr << "Input file [" << inputFileName << "] does not exist" << std::endl;
            return false;
        }

        if (fs::is_directory(inputFileName))
        {
            std::cerr << "Input file [" << inputFileName << "] is a directory" << std::endl;
            return false;
        }

        return true;
    }

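    // Only the NHWC and NCHW data layouts are supported.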
    bool ValidateLayout(const std::string& layout)
    {
        if (layout.empty())
        {
            std::cerr << "No layout specified" << std::endl;
            return false;
        }

        std::vector<std::string> supportedLayouts = { "NHWC", "NCHW" };

        auto iterator = std::find(supportedLayouts.begin(), supportedLayouts.end(), layout);
        if (iterator == supportedLayouts.end())
        {
            std::cerr << "Layout [" << layout << "] is not supported" << std::endl;
            return false;
        }

        return true;
    }

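    // The output file must not already exist and its parent directory must exist.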
    bool ValidateOutputFile(const std::string& outputFileName)
    {
        if (outputFileName.empty())
        {
            std::cerr << "No output file name specified" << std::endl;
            return false;
        }

        if (fs::exists(outputFileName))
        {
            std::cerr << "Output file [" << outputFileName << "] already exists" << std::endl;
            return false;
        }

        if (fs::is_directory(outputFileName))
        {
            std::cerr << "Output file [" << outputFileName << "] is a directory" << std::endl;
            return false;
        }

        fs::path outputPath(outputFileName);
        if (!fs::exists(outputPath.parent_path()))
        {
            std::cerr << "Output directory [" << outputPath.parent_path().c_str() << "] does not exist" << std::endl;
            return false;
        }

        return true;
    }

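    // Defines the command-line options, parses argv and validates the supplied values.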
    bool ProcessCommandLine(int argc, char* argv[])
    {
        cxxopts::Options options("ImageTensorGenerator",
                                 "Program for pre-processing a .jpg image "
                                 "before generating a .raw tensor file from it.");

        try
        {
            options.add_options()
                ("h,help", "Display help messages")
                ("i,infile",
                    "Input image file to generate tensor from",
                    cxxopts::value<std::string>(m_InputFileName))
                ("f,model-format",
                    "Format of the intended model file that uses the images. "
                    "Different formats have different image normalization styles. "
                    "If unset, defaults to tflite. "
                    "Accepted value (tflite)",
                    cxxopts::value<std::string>(m_ModelFormat)->default_value("tflite"))
                ("o,outfile",
                    "Output raw tensor file path",
                    cxxopts::value<std::string>(m_OutputFileName))
                ("z,output-type",
                    "The data type of the output tensors. "
                    "If unset, defaults to \"float\" for all defined inputs. "
                    "Accepted values (float, int, qasymms8 or qasymmu8)",
                    cxxopts::value<std::string>(m_OutputType)->default_value("float"))
                ("new-width",
                    "Resize image to new width. Keep original width if unspecified",
                    cxxopts::value<std::string>(m_NewWidth)->default_value("0"))
                ("new-height",
                    "Resize image to new height. Keep original height if unspecified",
                    cxxopts::value<std::string>(m_NewHeight)->default_value("0"))
                ("l,layout",
                    "Output data layout, \"NHWC\" or \"NCHW\", default value NHWC",
                    cxxopts::value<std::string>(m_Layout)->default_value("NHWC"));
        }
        catch (const std::exception& e)
        {
            std::cerr << options.help() << std::endl;
            return false;
        }

        try
        {
            auto result = options.parse(argc, argv);

            if (result.count("help"))
            {
                std::cout << options.help() << std::endl;
                return false;
            }

            // Check for mandatory parameters and validate inputs
            if (!ParseOptions(result))
            {
                return false;
            }
        }
        catch (const cxxopts::OptionException& e)
        {
            std::cerr << e.what() << std::endl << std::endl;
            return false;
        }

        return true;
    }

    std::string GetInputFileName() { return m_InputFileName; }
    armnn::DataLayout GetLayout()
    {
        if (m_Layout == "NHWC")
        {
            return armnn::DataLayout::NHWC;
        }
        else if (m_Layout == "NCHW")
        {
            return armnn::DataLayout::NCHW;
        }
        else
        {
            throw armnn::Exception("Unsupported data layout: " + m_Layout);
        }
    }
    std::string GetOutputFileName() { return m_OutputFileName; }
    unsigned int GetNewWidth() { return static_cast<unsigned int>(std::stoi(m_NewWidth)); }
    unsigned int GetNewHeight() { return static_cast<unsigned int>(std::stoi(m_NewHeight)); }
    SupportedFrontend GetModelFormat()
    {
        if (m_ModelFormat == "tflite")
        {
            return SupportedFrontend::TFLite;
        }
        else
        {
            throw armnn::Exception("Unsupported model format: " + m_ModelFormat);
        }
    }
    armnn::DataType GetOutputType()
    {
        if (m_OutputType == "float")
        {
            return armnn::DataType::Float32;
        }
        else if (m_OutputType == "int")
        {
            return armnn::DataType::Signed32;
        }
        else if (m_OutputType == "qasymm8" || m_OutputType == "qasymmu8")
        {
            return armnn::DataType::QAsymmU8;
        }
        else if (m_OutputType == "qasymms8")
        {
            return armnn::DataType::QAsymmS8;
        }
        else
        {
            throw armnn::Exception("Unsupported output type: " + m_OutputType);
        }
    }

private:
    std::string m_InputFileName;
    std::string m_Layout;
    std::string m_OutputFileName;
    std::string m_NewWidth;
    std::string m_NewHeight;
    std::string m_ModelFormat;
    std::string m_OutputType;
};

} // namespace anonymous

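// Reads the input image, pre-processes it according to the chosen model format, output type and
// layout, and writes the resulting tensor data to the output file.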
int main(int argc, char* argv[])
{
    CommandLineProcessor cmdline;
    if (!cmdline.ProcessCommandLine(argc, argv))
    {
        return -1;
    }
    const std::string imagePath(cmdline.GetInputFileName());
    const std::string outputPath(cmdline.GetOutputFileName());
    const SupportedFrontend& modelFormat(cmdline.GetModelFormat());
    const armnn::DataType outputType(cmdline.GetOutputType());
    const unsigned int newWidth  = cmdline.GetNewWidth();
    const unsigned int newHeight = cmdline.GetNewHeight();
    const unsigned int batchSize = 1;
    const armnn::DataLayout outputLayout(cmdline.GetLayout());

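    // Pre-process the image into a tensor whose element type matches the requested output type.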
    std::vector<armnnUtils::TContainer> imageDataContainers;
    const NormalizationParameters& normParams = GetNormalizationParameters(modelFormat, outputType);
    try
    {
        switch (outputType)
        {
            case armnn::DataType::Signed32:
                imageDataContainers.push_back(PrepareImageTensor<int>(
                    imagePath, newWidth, newHeight, normParams, batchSize, outputLayout));
                break;
            case armnn::DataType::QAsymmU8:
                imageDataContainers.push_back(PrepareImageTensor<uint8_t>(
                    imagePath, newWidth, newHeight, normParams, batchSize, outputLayout));
                break;
            case armnn::DataType::QAsymmS8:
                imageDataContainers.push_back(PrepareImageTensor<int8_t>(
                    imagePath, newWidth, newHeight, normParams, batchSize, outputLayout));
                break;
            case armnn::DataType::Float32:
            default:
                imageDataContainers.push_back(PrepareImageTensor<float>(
                    imagePath, newWidth, newHeight, normParams, batchSize, outputLayout));
                break;
        }
    }
    catch (const InferenceTestImageException& e)
    {
        ARMNN_LOG(fatal) << "Failed to load image file " << imagePath << " with error: " << e.what();
        return -1;
    }

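    // Write the generated tensor data to the output file.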
    std::ofstream imageTensorFile;
    imageTensorFile.open(outputPath, std::ofstream::out);
    if (imageTensorFile.is_open())
    {
        mapbox::util::apply_visitor(
            [&imageTensorFile](auto&& imageData) { WriteImageTensorImpl(imageData, imageTensorFile); },
            imageDataContainers[0]);

        if (!imageTensorFile)
        {
            ARMNN_LOG(fatal) << "Failed to write to output file " << outputPath;
            imageTensorFile.close();
            return -1;
        }
        imageTensorFile.close();
    }
    else
    {
        ARMNN_LOG(fatal) << "Failed to open output file " << outputPath;
        return -1;
    }

    return 0;
}