1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <map> 9 #include "QuantizationInput.hpp" 10 #include "armnn/LayerVisitorBase.hpp" 11 #include "armnn/Tensor.hpp" 12 13 namespace armnnQuantizer 14 { 15 16 /// QuantizationDataSet is a structure which is created after parsing a quantization CSV file. 17 /// It contains records of filenames which contain refinement data per pass ID for binding ID. 18 class QuantizationDataSet 19 { 20 using QuantizationInputs = std::vector<armnnQuantizer::QuantizationInput>; 21 public: 22 23 using iterator = QuantizationInputs::iterator; 24 using const_iterator = QuantizationInputs::const_iterator; 25 26 QuantizationDataSet(); 27 QuantizationDataSet(std::string csvFilePath); 28 ~QuantizationDataSet(); IsEmpty() const29 bool IsEmpty() const {return m_QuantizationInputs.empty();} 30 begin()31 iterator begin() { return m_QuantizationInputs.begin(); } end()32 iterator end() { return m_QuantizationInputs.end(); } begin() const33 const_iterator begin() const { return m_QuantizationInputs.begin(); } end() const34 const_iterator end() const { return m_QuantizationInputs.end(); } cbegin() const35 const_iterator cbegin() const { return m_QuantizationInputs.cbegin(); } cend() const36 const_iterator cend() const { return m_QuantizationInputs.cend(); } 37 38 private: 39 void ParseCsvFile(); 40 41 QuantizationInputs m_QuantizationInputs; 42 std::string m_CsvFilePath; 43 }; 44 45 /// Visitor class implementation to gather the TensorInfo for LayerBindingID for creation of ConstTensor for Refine. 46 class InputLayerVisitor : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy> 47 { 48 public: 49 void VisitInputLayer(const armnn::IConnectableLayer *layer, armnn::LayerBindingId id, const char* name); 50 armnn::TensorInfo GetTensorInfo(armnn::LayerBindingId); 51 private: 52 std::map<armnn::LayerBindingId, armnn::TensorInfo> m_TensorInfos; 53 }; 54 55 } // namespace armnnQuantizer