• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
#include "CommandLineProcessor.hpp"
#include <armnnDeserializer/IDeserializer.hpp>
#include <armnnQuantizer/INetworkQuantizer.hpp>
#include <armnnSerializer/ISerializer.hpp>
#include "QuantizationDataSet.hpp"
#include "QuantizationInput.hpp"

#include <algorithm>
#include <cstdint>
#include <exception>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>
16 
main(int argc,char * argv[])17 int main(int argc, char* argv[])
18 {
19     armnnQuantizer::CommandLineProcessor cmdline;
20     if (!cmdline.ProcessCommandLine(argc, argv))
21     {
22         return -1;
23     }
24     armnnDeserializer::IDeserializerPtr parser = armnnDeserializer::IDeserializer::Create();
25     std::ifstream inputFileStream(cmdline.GetInputFileName(), std::ios::binary);
26     std::vector<std::uint8_t> binaryContent;
27     while (inputFileStream)
28     {
29         char c;
30         inputFileStream.get(c);
31         if (inputFileStream)
32         {
33             binaryContent.push_back(static_cast<std::uint8_t>(c));
34         }
35     }
36     inputFileStream.close();
37 
38     armnn::QuantizerOptions quantizerOptions;
39 
40     if (cmdline.GetQuantizationScheme() == "QAsymmS8")
41     {
42         quantizerOptions.m_ActivationFormat = armnn::DataType::QAsymmS8;
43     }
44     else if (cmdline.GetQuantizationScheme() == "QSymmS16")
45     {
46         quantizerOptions.m_ActivationFormat = armnn::DataType::QSymmS16;
47     }
48     else
49     {
50         quantizerOptions.m_ActivationFormat = armnn::DataType::QAsymmU8;
51     }
52 
53     quantizerOptions.m_PreserveType = cmdline.HasPreservedDataType();
54 
55     armnn::INetworkPtr network = parser->CreateNetworkFromBinary(binaryContent);
56     armnn::INetworkQuantizerPtr quantizer = armnn::INetworkQuantizer::Create(network.get(), quantizerOptions);
57 
58     if (cmdline.HasQuantizationData())
59     {
60         armnnQuantizer::QuantizationDataSet dataSet = cmdline.GetQuantizationDataSet();
61         if (!dataSet.IsEmpty())
62         {
63             // Get the Input Tensor Infos
64             armnnQuantizer::InputLayerVisitor inputLayerVisitor;
65             network->Accept(inputLayerVisitor);
66 
67             for (armnnQuantizer::QuantizationInput quantizationInput : dataSet)
68             {
69                 armnn::InputTensors inputTensors;
70                 std::vector<std::vector<float>> inputData(quantizationInput.GetNumberOfInputs());
71                 std::vector<armnn::LayerBindingId> layerBindingIds = quantizationInput.GetLayerBindingIds();
72                 unsigned int count = 0;
73                 for (armnn::LayerBindingId layerBindingId : quantizationInput.GetLayerBindingIds())
74                 {
75                     armnn::TensorInfo tensorInfo = inputLayerVisitor.GetTensorInfo(layerBindingId);
76                     inputData[count] = quantizationInput.GetDataForEntry(layerBindingId);
77                     armnn::ConstTensor inputTensor(tensorInfo, inputData[count].data());
78                     inputTensors.push_back(std::make_pair(layerBindingId, inputTensor));
79                     count++;
80                 }
81                 quantizer->Refine(inputTensors);
82             }
83         }
84     }
85 
86     armnn::INetworkPtr quantizedNetwork = quantizer->ExportNetwork();
87     armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
88     serializer->Serialize(*quantizedNetwork);
89 
90     std::string output(cmdline.GetOutputDirectoryName());
91     output.append(cmdline.GetOutputFileName());
92     std::ofstream outputFileStream;
93     outputFileStream.open(output);
94     serializer->SaveSerializedToStream(outputFileStream);
95     outputFileStream.flush();
96     outputFileStream.close();
97 
98     return 0;
99 }