• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #pragma once
7 
#include <map>
#include <string>
#include <vector>

#include "QuantizationInput.hpp"
#include "armnn/LayerVisitorBase.hpp"
#include "armnn/Tensor.hpp"
12 
13 namespace armnnQuantizer
14 {
15 
16 /// QuantizationDataSet is a structure which is created after parsing a quantization CSV file.
17 /// It contains records of filenames which contain refinement data per pass ID for binding ID.
18 class QuantizationDataSet
19 {
20     using QuantizationInputs = std::vector<armnnQuantizer::QuantizationInput>;
21 public:
22 
23     using iterator = QuantizationInputs::iterator;
24     using const_iterator = QuantizationInputs::const_iterator;
25 
26     QuantizationDataSet();
27     QuantizationDataSet(std::string csvFilePath);
28     ~QuantizationDataSet();
IsEmpty() const29     bool IsEmpty() const {return m_QuantizationInputs.empty();}
30 
begin()31     iterator begin() { return m_QuantizationInputs.begin(); }
end()32     iterator end() { return m_QuantizationInputs.end(); }
begin() const33     const_iterator begin() const { return m_QuantizationInputs.begin(); }
end() const34     const_iterator end() const { return m_QuantizationInputs.end(); }
cbegin() const35     const_iterator cbegin() const { return m_QuantizationInputs.cbegin(); }
cend() const36     const_iterator cend() const { return m_QuantizationInputs.cend(); }
37 
38 private:
39     void ParseCsvFile();
40 
41     QuantizationInputs m_QuantizationInputs;
42     std::string m_CsvFilePath;
43 };
44 
45 /// Visitor class implementation to gather the TensorInfo for LayerBindingID for creation of ConstTensor for Refine.
46 class InputLayerVisitor : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
47 {
48 public:
49     void VisitInputLayer(const armnn::IConnectableLayer *layer, armnn::LayerBindingId id, const char* name);
50     armnn::TensorInfo GetTensorInfo(armnn::LayerBindingId);
51 private:
52     std::map<armnn::LayerBindingId, armnn::TensorInfo> m_TensorInfos;
53 };
54 
55 } // namespace armnnQuantizer