• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ImageTensorGenerator.hpp"
7 #include "../InferenceTestImage.hpp"
8 #include <armnn/Logging.hpp>
9 #include <armnn/TypesUtils.hpp>
10 #include <Filesystem.hpp>
11 
12 #include <mapbox/variant.hpp>
13 #include <cxxopts/cxxopts.hpp>
14 
15 #include <algorithm>
16 #include <fstream>
17 #include <iostream>
18 #include <string>
19 
20 namespace
21 {
22 
23 // parses the command line to extract
24 // * the input image file -i the input image file path (must exist)
25 // * the layout -l the data layout output generated with (optional - default value is NHWC)
26 // * the output file -o the output raw tensor file path (must not already exist)
27 class CommandLineProcessor
28 {
29 public:
ParseOptions(cxxopts::ParseResult & result)30     bool ParseOptions(cxxopts::ParseResult& result)
31     {
32         // infile is mandatory
33         if (result.count("infile"))
34         {
35             if (!ValidateInputFile(result["infile"].as<std::string>()))
36             {
37                 return false;
38             }
39         }
40         else
41         {
42             std::cerr << "-i/--infile parameter is mandatory." << std::endl;
43             return false;
44         }
45 
46         // model-format is mandatory
47         if (!result.count("model-format"))
48         {
49             std::cerr << "-f/--model-format parameter is mandatory." << std::endl;
50             return false;
51         }
52 
53         // outfile is mandatory
54         if (result.count("outfile"))
55         {
56             if (!ValidateOutputFile(result["outfile"].as<std::string>()))
57             {
58                 return false;
59             }
60         }
61         else
62         {
63             std::cerr << "-o/--outfile parameter is mandatory." << std::endl;
64             return false;
65         }
66 
67         if (result.count("layout"))
68         {
69             if(!ValidateLayout(result["layout"].as<std::string>()))
70             {
71                 return false;
72             }
73         }
74 
75         return true;
76     }
77 
ValidateInputFile(const std::string & inputFileName)78     bool ValidateInputFile(const std::string& inputFileName)
79     {
80         if (inputFileName.empty())
81         {
82             std::cerr << "No input file name specified" << std::endl;
83             return false;
84         }
85 
86         if (!fs::exists(inputFileName))
87         {
88             std::cerr << "Input file [" << inputFileName << "] does not exist" << std::endl;
89             return false;
90         }
91 
92         if (fs::is_directory(inputFileName))
93         {
94             std::cerr << "Input file [" << inputFileName << "] is a directory" << std::endl;
95             return false;
96         }
97 
98         return true;
99     }
100 
ValidateLayout(const std::string & layout)101     bool ValidateLayout(const std::string& layout)
102     {
103         if (layout.empty())
104         {
105             std::cerr << "No layout specified" << std::endl;
106             return false;
107         }
108 
109         std::vector<std::string> supportedLayouts = { "NHWC", "NCHW" };
110 
111         auto iterator = std::find(supportedLayouts.begin(), supportedLayouts.end(), layout);
112         if (iterator == supportedLayouts.end())
113         {
114             std::cerr << "Layout [" << layout << "] is not supported" << std::endl;
115             return false;
116         }
117 
118         return true;
119     }
120 
ValidateOutputFile(const std::string & outputFileName)121     bool ValidateOutputFile(const std::string& outputFileName)
122     {
123         if (outputFileName.empty())
124         {
125             std::cerr << "No output file name specified" << std::endl;
126             return false;
127         }
128 
129         if (fs::exists(outputFileName))
130         {
131             std::cerr << "Output file [" << outputFileName << "] already exists" << std::endl;
132             return false;
133         }
134 
135         if (fs::is_directory(outputFileName))
136         {
137             std::cerr << "Output file [" << outputFileName << "] is a directory" << std::endl;
138             return false;
139         }
140 
141         fs::path outputPath(outputFileName);
142         if (!fs::exists(outputPath.parent_path()))
143         {
144             std::cerr << "Output directory [" << outputPath.parent_path().c_str() << "] does not exist" << std::endl;
145             return false;
146         }
147 
148         return true;
149     }
150 
ProcessCommandLine(int argc,char * argv[])151     bool ProcessCommandLine(int argc, char* argv[])
152     {
153         cxxopts::Options options("ImageTensorGenerator",
154                                  "Program for pre-processing a .jpg image "
155                                  "before generating a .raw tensor file from it.");
156 
157         try
158         {
159             options.add_options()
160                 ("h,help", "Display help messages")
161                 ("i,infile",
162                     "Input image file to generate tensor from",
163                     cxxopts::value<std::string>(m_InputFileName))
164                 ("f,model-format",
165                     "Format of the intended model file that uses the images."
166                     "Different formats have different image normalization styles."
167                     "Accepted values (caffe, tensorflow, tflite)",
168                     cxxopts::value<std::string>(m_ModelFormat))
169                 ("o,outfile",
170                     "Output raw tensor file path",
171                     cxxopts::value<std::string>(m_OutputFileName))
172                 ("z,output-type",
173                     "The data type of the output tensors."
174                     "If unset, defaults to \"float\" for all defined inputs. "
175                     "Accepted values (float, int or qasymm8)",
176                     cxxopts::value<std::string>(m_OutputType)->default_value("float"))
177                 ("new-width",
178                     "Resize image to new width. Keep original width if unspecified",
179                     cxxopts::value<std::string>(m_NewWidth)->default_value("0"))
180                 ("new-height",
181                     "Resize image to new height. Keep original height if unspecified",
182                     cxxopts::value<std::string>(m_NewHeight)->default_value("0"))
183                 ("l,layout",
184                     "Output data layout, \"NHWC\" or \"NCHW\", default value NHWC",
185                     cxxopts::value<std::string>(m_Layout)->default_value("NHWC"));
186         }
187         catch (const std::exception& e)
188         {
189             std::cerr << options.help() << std::endl;
190             return false;
191         }
192 
193         try
194         {
195             auto result = options.parse(argc, argv);
196 
197             if (result.count("help"))
198             {
199                 std::cout << options.help() << std::endl;
200                 return false;
201             }
202 
203             // Check for mandatory parameters and validate inputs
204             if(!ParseOptions(result)){
205                 return false;
206             }
207         }
208         catch (const cxxopts::OptionException& e)
209         {
210             std::cerr << e.what() << std::endl << std::endl;
211             return false;
212         }
213 
214         return true;
215     }
216 
GetInputFileName()217     std::string GetInputFileName() {return m_InputFileName;}
GetLayout()218     armnn::DataLayout GetLayout()
219     {
220         if (m_Layout == "NHWC")
221         {
222             return armnn::DataLayout::NHWC;
223         }
224         else if (m_Layout == "NCHW")
225         {
226             return armnn::DataLayout::NCHW;
227         }
228         else
229         {
230             throw armnn::Exception("Unsupported data layout: " + m_Layout);
231         }
232     }
GetOutputFileName()233     std::string GetOutputFileName() {return m_OutputFileName;}
GetNewWidth()234     unsigned int GetNewWidth() {return static_cast<unsigned int>(std::stoi(m_NewWidth));}
GetNewHeight()235     unsigned int GetNewHeight() {return static_cast<unsigned int>(std::stoi(m_NewHeight));}
GetModelFormat()236     SupportedFrontend GetModelFormat()
237     {
238         if (m_ModelFormat == "caffe")
239         {
240             return SupportedFrontend::Caffe;
241         }
242         else if (m_ModelFormat == "tensorflow")
243         {
244             return SupportedFrontend::TensorFlow;
245         }
246         else if (m_ModelFormat == "tflite")
247         {
248             return SupportedFrontend::TFLite;
249         }
250         else
251         {
252             throw armnn::Exception("Unsupported model format" + m_ModelFormat);
253         }
254     }
GetOutputType()255     armnn::DataType GetOutputType()
256     {
257         if (m_OutputType == "float")
258         {
259             return armnn::DataType::Float32;
260         }
261         else if (m_OutputType == "int")
262         {
263             return armnn::DataType::Signed32;
264         }
265         else if (m_OutputType == "qasymm8")
266         {
267             return armnn::DataType::QAsymmU8;
268         }
269         else
270         {
271             throw armnn::Exception("Unsupported input type" + m_OutputType);
272         }
273     }
274 
275 private:
276     std::string m_InputFileName;
277     std::string m_Layout;
278     std::string m_OutputFileName;
279     std::string m_NewWidth;
280     std::string m_NewHeight;
281     std::string m_ModelFormat;
282     std::string m_OutputType;
283 };
284 
285 } // namespace anonymous
286 
main(int argc,char * argv[])287 int main(int argc, char* argv[])
288 {
289     CommandLineProcessor cmdline;
290     if (!cmdline.ProcessCommandLine(argc, argv))
291     {
292         return -1;
293     }
294     const std::string imagePath(cmdline.GetInputFileName());
295     const std::string outputPath(cmdline.GetOutputFileName());
296     const SupportedFrontend& modelFormat(cmdline.GetModelFormat());
297     const armnn::DataType outputType(cmdline.GetOutputType());
298     const unsigned int newWidth  = cmdline.GetNewWidth();
299     const unsigned int newHeight = cmdline.GetNewHeight();
300     const unsigned int batchSize = 1;
301     const armnn::DataLayout outputLayout(cmdline.GetLayout());
302 
303     using TContainer = mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<uint8_t>>;
304     std::vector<TContainer> imageDataContainers;
305     const NormalizationParameters& normParams = GetNormalizationParameters(modelFormat, outputType);
306     try
307     {
308         switch (outputType)
309         {
310             case armnn::DataType::Signed32:
311                 imageDataContainers.push_back(PrepareImageTensor<int>(
312                     imagePath, newWidth, newHeight, normParams, batchSize, outputLayout));
313                 break;
314             case armnn::DataType::QAsymmU8:
315                 imageDataContainers.push_back(PrepareImageTensor<uint8_t>(
316                     imagePath, newWidth, newHeight, normParams, batchSize, outputLayout));
317                 break;
318             case armnn::DataType::Float32:
319             default:
320                 imageDataContainers.push_back(PrepareImageTensor<float>(
321                     imagePath, newWidth, newHeight, normParams, batchSize, outputLayout));
322                 break;
323         }
324     }
325     catch (const InferenceTestImageException& e)
326     {
327         ARMNN_LOG(fatal) << "Failed to load image file " << imagePath << " with error: " << e.what();
328         return -1;
329     }
330 
331     std::ofstream imageTensorFile;
332     imageTensorFile.open(outputPath, std::ofstream::out);
333     if (imageTensorFile.is_open())
334     {
335         mapbox::util::apply_visitor(
336             [&imageTensorFile](auto&& imageData){ WriteImageTensorImpl(imageData,imageTensorFile); },
337             imageDataContainers[0]
338             );
339 
340         if (!imageTensorFile)
341         {
342             ARMNN_LOG(fatal) << "Failed to write to output file" << outputPath;
343             imageTensorFile.close();
344             return -1;
345         }
346         imageTensorFile.close();
347     }
348     else
349     {
350         ARMNN_LOG(fatal) << "Failed to open output file" << outputPath;
351         return -1;
352     }
353 
354     return 0;
355 }
356