• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/INetwork.hpp>
9 #include <armnnQuantizer/INetworkQuantizer.hpp>
10 #include <armnn/IRuntime.hpp>
11 #include <armnn/Types.hpp>
12 #include <armnn/Optional.hpp>
13 
14 #include "DynamicQuantizationVisitor.hpp"
15 #include "RangeTracker.hpp"
16 
17 namespace armnn
18 {
19 
20 class NetworkQuantizer : public INetworkQuantizer
21 {
22 public:
NetworkQuantizer(INetwork * inputNetwork,const QuantizerOptions & options)23     NetworkQuantizer(INetwork* inputNetwork, const QuantizerOptions& options)
24     : m_InputNetwork(inputNetwork),
25       m_NetworkId(0),
26       m_Runtime(nullptr, &IRuntime::Destroy),
27       m_RefineCount(0),
28       m_Options(options) {}
29 
30     void OverrideInputRange(LayerBindingId layerId, float min, float max) override;
31     void Refine(const InputTensors& inputTensors) override;
32 
33     // Required for testing? Need some way to get min/max in RangeTracker (m_Ranges)
GetMinMaxRange(LayerGuid guid,unsigned int idx)34     std::pair<float, float> GetMinMaxRange(LayerGuid guid, unsigned int idx) { return m_Ranges.GetRange(guid, idx); }
35     INetworkPtr ExportNetwork() override;
36 
37 private:
38     /// Original input network to quantize
39     INetwork* m_InputNetwork;
40 
41     NetworkId m_NetworkId;
42 
43     // if we are run in dynamic mode this unique pointer will hold
44     // the runtime between invocations of the Refine method.
45     IRuntimePtr m_Runtime;
46 
47     Optional<DynamicQuantizationVisitor> m_DynamicQuantizationVisitor;
48 
49     // counts the number of times refine is called
50     unsigned int m_RefineCount;
51 
52     /// Mapping from Guid to an array of ranges for outputs
53     RangeTracker m_Ranges;
54 
55     /// Options for the NetworkQuantizer
56     QuantizerOptions m_Options;
57 
58     std::pair<float, float> FindMinMax(ITensorHandle* tensorHandle);
59 };
60 
61 } //namespace armnn
62