1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "armnn/LayerVisitorBase.hpp" 9 #include "RangeTracker.hpp" 10 #include "layers/DebugLayer.hpp" 11 12 #include <armnn/INetwork.hpp> 13 #include <armnnQuantizer/INetworkQuantizer.hpp> 14 15 namespace armnn 16 { 17 18 /// Visitor class to establish min/max ranges based on the type of the layer 19 class DynamicQuantizationVisitor : public LayerVisitorBase<VisitorThrowingPolicy> 20 { 21 public: 22 DynamicQuantizationVisitor(RangeTracker& rangeTracker, Graph& graph); 23 ~DynamicQuantizationVisitor() = default; 24 25 /// Functions to set the Range on a per-layer-type basis 26 void VisitAbsLayer(const IConnectableLayer* layer, 27 const char* name = nullptr) override; 28 29 void VisitAdditionLayer(const IConnectableLayer* layer, 30 const char* name = nullptr) override; 31 32 void VisitArgMinMaxLayer(const IConnectableLayer* layer, 33 const ArgMinMaxDescriptor& desc, 34 const char* name = nullptr) override; 35 36 void VisitNormalizationLayer(const IConnectableLayer* layer, 37 const NormalizationDescriptor& desc, 38 const char* name = nullptr) override ; 39 40 void VisitBatchNormalizationLayer(const IConnectableLayer* layer, 41 const BatchNormalizationDescriptor& desc, 42 const ConstTensor& mean, 43 const ConstTensor& variance, 44 const ConstTensor& beta, 45 const ConstTensor& gamma, 46 const char* name = nullptr) override; 47 48 void VisitConvolution2dLayer(const IConnectableLayer* layer, 49 const Convolution2dDescriptor& convolution2dDescriptor, 50 const ConstTensor& weights, 51 const Optional<ConstTensor>& biases, 52 const char* name = nullptr) override; 53 54 void VisitDepthwiseConvolution2dLayer(const IConnectableLayer* layer, 55 const DepthwiseConvolution2dDescriptor& desc, 56 const ConstTensor& weights, 57 const Optional<ConstTensor>& biases, 58 const char* name = nullptr) override; 59 60 void VisitActivationLayer(const IConnectableLayer* layer, 61 const ActivationDescriptor& activationDescriptor, 62 const char* name = nullptr) override; 63 64 void VisitFullyConnectedLayer(const IConnectableLayer *layer, 65 const FullyConnectedDescriptor& desc, 66 const ConstTensor& weights, 67 const Optional<ConstTensor>& biases, 68 const char *name) override; 69 70 void VisitPermuteLayer(const IConnectableLayer* layer, 71 const PermuteDescriptor& permuteDescriptor, 72 const char* name) override; 73 74 void VisitSpaceToBatchNdLayer(const IConnectableLayer* layer, 75 const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor, 76 const char* name = nullptr) override; 77 78 void VisitPooling2dLayer(const IConnectableLayer* layer, 79 const Pooling2dDescriptor& pooling2dDescriptor, 80 const char* name) override; 81 82 void VisitSoftmaxLayer(const IConnectableLayer* layer, 83 const SoftmaxDescriptor& softmaxDescriptor, 84 const char* name = nullptr) override; 85 86 void VisitConcatLayer(const IConnectableLayer* layer, 87 const ConcatDescriptor& originsDescriptor, 88 const char* name = nullptr) override; 89 90 void VisitConstantLayer(const IConnectableLayer* layer, 91 const ConstTensor& input, 92 const char* name = nullptr) override; 93 94 void VisitReshapeLayer(const IConnectableLayer* layer, 95 const ReshapeDescriptor& reshapeDescriptor, 96 const char* name = nullptr) override; 97 98 void VisitSplitterLayer(const IConnectableLayer* layer, 99 const SplitterDescriptor& splitterDescriptor, 100 const char* name = nullptr) override; 101 102 void VisitResizeBilinearLayer(const IConnectableLayer* layer, 103 const ResizeBilinearDescriptor& resizeDesc, 104 const char* name = nullptr) override; 105 106 void VisitStridedSliceLayer(const IConnectableLayer* layer, 107 const StridedSliceDescriptor& stridedSliceDescriptor, 108 const char* name = nullptr) override; 109 110 void VisitBatchToSpaceNdLayer(const IConnectableLayer* layer, 111 const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor, 112 const char* name = nullptr) override; 113 114 void VisitInputLayer(const IConnectableLayer* layer, 115 LayerBindingId id, 116 const char* name = nullptr) override; 117 118 void VisitOutputLayer(const IConnectableLayer* layer, 119 LayerBindingId id, 120 const char* name = nullptr) override; 121 122 void FinishVisit() override; 123 void VisitNonCalibratedLayers(); 124 125 const std::vector<armnn::LayerBindingId>& GetOutputLayers(); 126 127 private: 128 /// Set the range for an output slot on a layer 129 void SetRange(const IConnectableLayer* layer, unsigned int outputIdx, float min, float max); 130 131 void ForwardParentParameters(const IConnectableLayer* layer); 132 133 /// Mapping from a layer Guid to an array of ranges for outputs 134 RangeTracker& m_RangeTracker; 135 136 Graph& m_Graph; 137 138 std::vector<const IConnectableLayer*> m_LayersToCalibrate; 139 std::vector<const IConnectableLayer*> m_LayersNotToCalibrate; 140 std::vector<DebugLayer*> m_DebugLayers; 141 142 std::vector<armnn::LayerBindingId> m_OutputLayers; 143 144 void AddToCalibratedLayers(const IConnectableLayer* layer); 145 void AddToNonCalibratedLayers(const IConnectableLayer* layer); 146 void RemoveDebugLayers(); 147 }; 148 149 } //namespace armnn 150