1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/INetwork.hpp> 9 #include <armnnQuantizer/INetworkQuantizer.hpp> 10 #include <armnn/IRuntime.hpp> 11 #include <armnn/Types.hpp> 12 #include <armnn/Optional.hpp> 13 14 #include "DynamicQuantizationVisitor.hpp" 15 #include "RangeTracker.hpp" 16 17 namespace armnn 18 { 19 20 class NetworkQuantizer : public INetworkQuantizer 21 { 22 public: NetworkQuantizer(INetwork * inputNetwork,const QuantizerOptions & options)23 NetworkQuantizer(INetwork* inputNetwork, const QuantizerOptions& options) 24 : m_InputNetwork(inputNetwork), 25 m_NetworkId(0), 26 m_Runtime(nullptr, &IRuntime::Destroy), 27 m_RefineCount(0), 28 m_Options(options) {} 29 30 void OverrideInputRange(LayerBindingId layerId, float min, float max) override; 31 void Refine(const InputTensors& inputTensors) override; 32 33 // Required for testing? Need some way to get min/max in RangeTracker (m_Ranges) GetMinMaxRange(LayerGuid guid,unsigned int idx)34 std::pair<float, float> GetMinMaxRange(LayerGuid guid, unsigned int idx) { return m_Ranges.GetRange(guid, idx); } 35 INetworkPtr ExportNetwork() override; 36 37 private: 38 /// Original input network to quantize 39 INetwork* m_InputNetwork; 40 41 NetworkId m_NetworkId; 42 43 // if we are run in dynamic mode this unique pointer will hold 44 // the runtime between invocations of the Refine method. 45 IRuntimePtr m_Runtime; 46 47 Optional<DynamicQuantizationVisitor> m_DynamicQuantizationVisitor; 48 49 // counts the number of times refine is called 50 unsigned int m_RefineCount; 51 52 /// Mapping from Guid to an array of ranges for outputs 53 RangeTracker m_Ranges; 54 55 /// Options for the NetworkQuantizer 56 QuantizerOptions m_Options; 57 58 std::pair<float, float> FindMinMax(ITensorHandle* tensorHandle); 59 }; 60 61 } //namespace armnn 62