1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "CommandLineProcessor.hpp" 7 #include <armnnDeserializer/IDeserializer.hpp> 8 #include <armnnQuantizer/INetworkQuantizer.hpp> 9 #include <armnnSerializer/ISerializer.hpp> 10 #include "QuantizationDataSet.hpp" 11 #include "QuantizationInput.hpp" 12 13 #include <algorithm> 14 #include <fstream> 15 #include <iostream> 16 main(int argc,char * argv[])17int main(int argc, char* argv[]) 18 { 19 armnnQuantizer::CommandLineProcessor cmdline; 20 if (!cmdline.ProcessCommandLine(argc, argv)) 21 { 22 return -1; 23 } 24 armnnDeserializer::IDeserializerPtr parser = armnnDeserializer::IDeserializer::Create(); 25 std::ifstream inputFileStream(cmdline.GetInputFileName(), std::ios::binary); 26 std::vector<std::uint8_t> binaryContent; 27 while (inputFileStream) 28 { 29 char c; 30 inputFileStream.get(c); 31 if (inputFileStream) 32 { 33 binaryContent.push_back(static_cast<std::uint8_t>(c)); 34 } 35 } 36 inputFileStream.close(); 37 38 armnn::QuantizerOptions quantizerOptions; 39 40 if (cmdline.GetQuantizationScheme() == "QAsymmS8") 41 { 42 quantizerOptions.m_ActivationFormat = armnn::DataType::QAsymmS8; 43 } 44 else if (cmdline.GetQuantizationScheme() == "QSymmS16") 45 { 46 quantizerOptions.m_ActivationFormat = armnn::DataType::QSymmS16; 47 } 48 else 49 { 50 quantizerOptions.m_ActivationFormat = armnn::DataType::QAsymmU8; 51 } 52 53 quantizerOptions.m_PreserveType = cmdline.HasPreservedDataType(); 54 55 armnn::INetworkPtr network = parser->CreateNetworkFromBinary(binaryContent); 56 armnn::INetworkQuantizerPtr quantizer = armnn::INetworkQuantizer::Create(network.get(), quantizerOptions); 57 58 if (cmdline.HasQuantizationData()) 59 { 60 armnnQuantizer::QuantizationDataSet dataSet = cmdline.GetQuantizationDataSet(); 61 if (!dataSet.IsEmpty()) 62 { 63 // Get the Input Tensor Infos 64 armnnQuantizer::InputLayerVisitor inputLayerVisitor; 65 network->Accept(inputLayerVisitor); 66 67 for (armnnQuantizer::QuantizationInput quantizationInput : dataSet) 68 { 69 armnn::InputTensors inputTensors; 70 std::vector<std::vector<float>> inputData(quantizationInput.GetNumberOfInputs()); 71 std::vector<armnn::LayerBindingId> layerBindingIds = quantizationInput.GetLayerBindingIds(); 72 unsigned int count = 0; 73 for (armnn::LayerBindingId layerBindingId : quantizationInput.GetLayerBindingIds()) 74 { 75 armnn::TensorInfo tensorInfo = inputLayerVisitor.GetTensorInfo(layerBindingId); 76 inputData[count] = quantizationInput.GetDataForEntry(layerBindingId); 77 armnn::ConstTensor inputTensor(tensorInfo, inputData[count].data()); 78 inputTensors.push_back(std::make_pair(layerBindingId, inputTensor)); 79 count++; 80 } 81 quantizer->Refine(inputTensors); 82 } 83 } 84 } 85 86 armnn::INetworkPtr quantizedNetwork = quantizer->ExportNetwork(); 87 armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create(); 88 serializer->Serialize(*quantizedNetwork); 89 90 std::string output(cmdline.GetOutputDirectoryName()); 91 output.append(cmdline.GetOutputFileName()); 92 std::ofstream outputFileStream; 93 outputFileStream.open(output); 94 serializer->SaveSerializedToStream(outputFileStream); 95 outputFileStream.flush(); 96 outputFileStream.close(); 97 98 return 0; 99 }