• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ImageTensorGenerator.hpp"
7 #include "../InferenceTestImage.hpp"
8 #include <armnn/Logging.hpp>
9 #include <armnn/TypesUtils.hpp>
10 #include <armnnUtils/Filesystem.hpp>
11 
12 #include <cxxopts/cxxopts.hpp>
13 
14 #include <algorithm>
15 #include <fstream>
16 #include <iostream>
17 #include <string>
18 
19 namespace
20 {
21 
22 // parses the command line to extract
23 // * the input image file -i the input image file path (must exist)
24 // * the layout -l the data layout output generated with (optional - default value is NHWC)
25 // * the output file -o the output raw tensor file path (must not already exist)
26 class CommandLineProcessor
27 {
28 public:
ParseOptions(cxxopts::ParseResult & result)29     bool ParseOptions(cxxopts::ParseResult& result)
30     {
31         // infile is mandatory
32         if (result.count("infile"))
33         {
34             if (!ValidateInputFile(result["infile"].as<std::string>()))
35             {
36                 return false;
37             }
38         }
39         else
40         {
41             std::cerr << "-i/--infile parameter is mandatory." << std::endl;
42             return false;
43         }
44 
45         // model-format is mandatory
46         if (!result.count("model-format"))
47         {
48             std::cerr << "-f/--model-format parameter is mandatory." << std::endl;
49             return false;
50         }
51 
52         // outfile is mandatory
53         if (result.count("outfile"))
54         {
55             if (!ValidateOutputFile(result["outfile"].as<std::string>()))
56             {
57                 return false;
58             }
59         }
60         else
61         {
62             std::cerr << "-o/--outfile parameter is mandatory." << std::endl;
63             return false;
64         }
65 
66         if (result.count("layout"))
67         {
68             if(!ValidateLayout(result["layout"].as<std::string>()))
69             {
70                 return false;
71             }
72         }
73 
74         return true;
75     }
76 
ValidateInputFile(const std::string & inputFileName)77     bool ValidateInputFile(const std::string& inputFileName)
78     {
79         if (inputFileName.empty())
80         {
81             std::cerr << "No input file name specified" << std::endl;
82             return false;
83         }
84 
85         if (!fs::exists(inputFileName))
86         {
87             std::cerr << "Input file [" << inputFileName << "] does not exist" << std::endl;
88             return false;
89         }
90 
91         if (fs::is_directory(inputFileName))
92         {
93             std::cerr << "Input file [" << inputFileName << "] is a directory" << std::endl;
94             return false;
95         }
96 
97         return true;
98     }
99 
ValidateLayout(const std::string & layout)100     bool ValidateLayout(const std::string& layout)
101     {
102         if (layout.empty())
103         {
104             std::cerr << "No layout specified" << std::endl;
105             return false;
106         }
107 
108         std::vector<std::string> supportedLayouts = { "NHWC", "NCHW" };
109 
110         auto iterator = std::find(supportedLayouts.begin(), supportedLayouts.end(), layout);
111         if (iterator == supportedLayouts.end())
112         {
113             std::cerr << "Layout [" << layout << "] is not supported" << std::endl;
114             return false;
115         }
116 
117         return true;
118     }
119 
ValidateOutputFile(const std::string & outputFileName)120     bool ValidateOutputFile(const std::string& outputFileName)
121     {
122         if (outputFileName.empty())
123         {
124             std::cerr << "No output file name specified" << std::endl;
125             return false;
126         }
127 
128         if (fs::exists(outputFileName))
129         {
130             std::cerr << "Output file [" << outputFileName << "] already exists" << std::endl;
131             return false;
132         }
133 
134         if (fs::is_directory(outputFileName))
135         {
136             std::cerr << "Output file [" << outputFileName << "] is a directory" << std::endl;
137             return false;
138         }
139 
140         fs::path outputPath(outputFileName);
141         if (!fs::exists(outputPath.parent_path()))
142         {
143             std::cerr << "Output directory [" << outputPath.parent_path().c_str() << "] does not exist" << std::endl;
144             return false;
145         }
146 
147         return true;
148     }
149 
ProcessCommandLine(int argc,char * argv[])150     bool ProcessCommandLine(int argc, char* argv[])
151     {
152         cxxopts::Options options("ImageTensorGenerator",
153                                  "Program for pre-processing a .jpg image "
154                                  "before generating a .raw tensor file from it.");
155 
156         try
157         {
158             options.add_options()
159                 ("h,help", "Display help messages")
160                 ("i,infile",
161                     "Input image file to generate tensor from",
162                     cxxopts::value<std::string>(m_InputFileName))
163                 ("f,model-format",
164                     "Format of the intended model file that uses the images."
165                     "Different formats have different image normalization styles."
166                     "If unset, defaults to tflite."
167                     "Accepted value (tflite)",
168                     cxxopts::value<std::string>(m_ModelFormat)->default_value("tflite"))
169                 ("o,outfile",
170                     "Output raw tensor file path",
171                     cxxopts::value<std::string>(m_OutputFileName))
172                 ("z,output-type",
173                     "The data type of the output tensors."
174                     "If unset, defaults to \"float\" for all defined inputs. "
175                     "Accepted values (float, int, qasymms8 or qasymmu8)",
176                     cxxopts::value<std::string>(m_OutputType)->default_value("float"))
177                 ("new-width",
178                     "Resize image to new width. Keep original width if unspecified",
179                     cxxopts::value<std::string>(m_NewWidth)->default_value("0"))
180                 ("new-height",
181                     "Resize image to new height. Keep original height if unspecified",
182                     cxxopts::value<std::string>(m_NewHeight)->default_value("0"))
183                 ("l,layout",
184                     "Output data layout, \"NHWC\" or \"NCHW\", default value NHWC",
185                     cxxopts::value<std::string>(m_Layout)->default_value("NHWC"));
186         }
187         catch (const std::exception& e)
188         {
189             std::cerr << options.help() << std::endl;
190             return false;
191         }
192 
193         try
194         {
195             auto result = options.parse(argc, argv);
196 
197             if (result.count("help"))
198             {
199                 std::cout << options.help() << std::endl;
200                 return false;
201             }
202 
203             // Check for mandatory parameters and validate inputs
204             if(!ParseOptions(result)){
205                 return false;
206             }
207         }
208         catch (const cxxopts::OptionException& e)
209         {
210             std::cerr << e.what() << std::endl << std::endl;
211             return false;
212         }
213 
214         return true;
215     }
216 
GetInputFileName()217     std::string GetInputFileName() {return m_InputFileName;}
GetLayout()218     armnn::DataLayout GetLayout()
219     {
220         if (m_Layout == "NHWC")
221         {
222             return armnn::DataLayout::NHWC;
223         }
224         else if (m_Layout == "NCHW")
225         {
226             return armnn::DataLayout::NCHW;
227         }
228         else
229         {
230             throw armnn::Exception("Unsupported data layout: " + m_Layout);
231         }
232     }
GetOutputFileName()233     std::string GetOutputFileName() {return m_OutputFileName;}
GetNewWidth()234     unsigned int GetNewWidth() {return static_cast<unsigned int>(std::stoi(m_NewWidth));}
GetNewHeight()235     unsigned int GetNewHeight() {return static_cast<unsigned int>(std::stoi(m_NewHeight));}
GetModelFormat()236     SupportedFrontend GetModelFormat()
237     {
238         if (m_ModelFormat == "tflite")
239         {
240             return SupportedFrontend::TFLite;
241         }
242         else
243         {
244             throw armnn::Exception("Unsupported model format" + m_ModelFormat);
245         }
246     }
GetOutputType()247     armnn::DataType GetOutputType()
248     {
249         if (m_OutputType == "float")
250         {
251             return armnn::DataType::Float32;
252         }
253         else if (m_OutputType == "int")
254         {
255             return armnn::DataType::Signed32;
256         }
257         else if (m_OutputType == "qasymm8" || m_OutputType == "qasymmu8")
258         {
259             return armnn::DataType::QAsymmU8;
260         }
261         else if (m_OutputType == "qasymms8")
262         {
263             return armnn::DataType::QAsymmS8;
264         }
265         else
266         {
267             throw armnn::Exception("Unsupported input type " + m_OutputType);
268         }
269     }
270 
271 private:
272     std::string m_InputFileName;
273     std::string m_Layout;
274     std::string m_OutputFileName;
275     std::string m_NewWidth;
276     std::string m_NewHeight;
277     std::string m_ModelFormat;
278     std::string m_OutputType;
279 };
280 
281 } // namespace anonymous
282 
main(int argc,char * argv[])283 int main(int argc, char* argv[])
284 {
285     CommandLineProcessor cmdline;
286     if (!cmdline.ProcessCommandLine(argc, argv))
287     {
288         return -1;
289     }
290     const std::string imagePath(cmdline.GetInputFileName());
291     const std::string outputPath(cmdline.GetOutputFileName());
292     const SupportedFrontend& modelFormat(cmdline.GetModelFormat());
293     const armnn::DataType outputType(cmdline.GetOutputType());
294     const unsigned int newWidth  = cmdline.GetNewWidth();
295     const unsigned int newHeight = cmdline.GetNewHeight();
296     const unsigned int batchSize = 1;
297     const armnn::DataLayout outputLayout(cmdline.GetLayout());
298 
299     std::vector<armnnUtils::TContainer> imageDataContainers;
300     const NormalizationParameters& normParams = GetNormalizationParameters(modelFormat, outputType);
301     try
302     {
303         switch (outputType)
304         {
305             case armnn::DataType::Signed32:
306                 imageDataContainers.push_back(PrepareImageTensor<int>(
307                     imagePath, newWidth, newHeight, normParams, batchSize, outputLayout));
308                 break;
309             case armnn::DataType::QAsymmU8:
310                 imageDataContainers.push_back(PrepareImageTensor<uint8_t>(
311                     imagePath, newWidth, newHeight, normParams, batchSize, outputLayout));
312                 break;
313             case armnn::DataType::QAsymmS8:
314                 imageDataContainers.push_back(PrepareImageTensor<int8_t>(
315                         imagePath, newWidth, newHeight, normParams, batchSize, outputLayout));
316                 break;
317             case armnn::DataType::Float32:
318             default:
319                 imageDataContainers.push_back(PrepareImageTensor<float>(
320                     imagePath, newWidth, newHeight, normParams, batchSize, outputLayout));
321                 break;
322         }
323     }
324     catch (const InferenceTestImageException& e)
325     {
326         ARMNN_LOG(fatal) << "Failed to load image file " << imagePath << " with error: " << e.what();
327         return -1;
328     }
329 
330     std::ofstream imageTensorFile;
331     imageTensorFile.open(outputPath, std::ofstream::out);
332     if (imageTensorFile.is_open())
333     {
334         mapbox::util::apply_visitor(
335             [&imageTensorFile](auto&& imageData){ WriteImageTensorImpl(imageData,imageTensorFile); },
336             imageDataContainers[0]
337             );
338 
339         if (!imageTensorFile)
340         {
341             ARMNN_LOG(fatal) << "Failed to write to output file" << outputPath;
342             imageTensorFile.close();
343             return -1;
344         }
345         imageTensorFile.close();
346     }
347     else
348     {
349         ARMNN_LOG(fatal) << "Failed to open output file" << outputPath;
350         return -1;
351     }
352 
353     return 0;
354 }
355