1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "armnn/LayerVisitorBase.hpp" 9 #include "StaticRangeVisitor.hpp" 10 #include "NetworkQuantizationScheme.hpp" 11 12 #include <armnn/INetwork.hpp> 13 #include <armnn/Types.hpp> 14 #include <armnnQuantizer/INetworkQuantizer.hpp> 15 16 #include <unordered_map> 17 18 namespace armnn 19 { 20 21 // Forward declaration 22 class StaticRangeVisitor; 23 24 /// Visitor object for quantizing layers in a network 25 class QuantizerVisitor : public LayerVisitorBase<VisitorThrowingPolicy> 26 { 27 public: 28 QuantizerVisitor(const RangeTracker& rangeTracker, 29 const IQuantizationScheme* quantizationScheme, 30 bool preserveType = false); 31 32 ~QuantizerVisitor() = default; 33 34 /// Functions to quantize the individual layers, overridden from ILayerVisitor 35 ARMNN_DEPRECATED_MSG("Use VisitElementwiseUnaryLayer instead") 36 void VisitAbsLayer(const IConnectableLayer* layer, const char* name = nullptr) override; 37 38 void VisitActivationLayer(const IConnectableLayer* layer, 39 const ActivationDescriptor& activationDescriptor, 40 const char* name = nullptr) override; 41 42 void VisitAdditionLayer(const IConnectableLayer* layer, const char* name = nullptr) override; 43 44 void VisitArgMinMaxLayer(const IConnectableLayer* layer, 45 const ArgMinMaxDescriptor& argMinMaxDescriptor, 46 const char* name = nullptr) override; 47 48 void VisitBatchNormalizationLayer(const IConnectableLayer* layer, 49 const BatchNormalizationDescriptor& desc, 50 const ConstTensor& mean, 51 const ConstTensor& variance, 52 const ConstTensor& beta, 53 const ConstTensor& gamma, 54 const char* name = nullptr) override; 55 56 void VisitBatchToSpaceNdLayer(const IConnectableLayer* layer, 57 const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor, 58 const char* name = nullptr) override; 59 60 void VisitComparisonLayer(const IConnectableLayer* layer, 61 const ComparisonDescriptor& comparisonDescriptor, 62 const char* name = nullptr) override; 63 64 void VisitConcatLayer(const IConnectableLayer* layer, 65 const OriginsDescriptor& originsDescriptor, 66 const char* name = nullptr) override; 67 68 void VisitConstantLayer(const IConnectableLayer* layer, 69 const ConstTensor& input, 70 const char* name = nullptr) override; 71 72 void VisitConvolution2dLayer(const IConnectableLayer* layer, 73 const Convolution2dDescriptor& convolution2dDescriptor, 74 const ConstTensor& weights, 75 const Optional<ConstTensor>& biases, 76 const char* name = nullptr) override; 77 78 void VisitDepthToSpaceLayer(const IConnectableLayer* layer, 79 const DepthToSpaceDescriptor& depthToSpaceDescriptor, 80 const char* name = nullptr) override; 81 82 void VisitDepthwiseConvolution2dLayer(const IConnectableLayer* layer, 83 const DepthwiseConvolution2dDescriptor& desc, 84 const ConstTensor& weights, 85 const Optional<ConstTensor>& biases, 86 const char* name = nullptr) override; 87 88 void VisitElementwiseUnaryLayer(const IConnectableLayer* layer, 89 const ElementwiseUnaryDescriptor& elementwiseUnaryDescriptor, 90 const char* name = nullptr) override; 91 92 void VisitFillLayer(const IConnectableLayer* layer, 93 const FillDescriptor& desc, 94 const char* name) override; 95 96 void VisitFullyConnectedLayer(const IConnectableLayer *layer, 97 const FullyConnectedDescriptor& desc, 98 const ConstTensor& weights, 99 const Optional<ConstTensor>& biases, 100 const char *name = nullptr) override; 101 102 void VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name = nullptr) override; 103 104 void VisitInstanceNormalizationLayer(const IConnectableLayer* layer, 105 const InstanceNormalizationDescriptor& instanceNormalizationDescriptor, 106 const char* name = nullptr) override; 107 108 void VisitLogSoftmaxLayer(const IConnectableLayer* layer, 109 const LogSoftmaxDescriptor& logSoftmaxDescriptor, 110 const char* name = nullptr) override; 111 112 void VisitMeanLayer(const IConnectableLayer* layer, 113 const MeanDescriptor& meanDescriptor, 114 const char* name = nullptr) override; 115 116 void VisitMultiplicationLayer(const IConnectableLayer* layer, 117 const char* name = nullptr) override; 118 119 void VisitNormalizationLayer(const IConnectableLayer* layer, 120 const NormalizationDescriptor& normalizationDescriptor, 121 const char* name = nullptr) override; 122 123 void VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name = nullptr) override; 124 125 void VisitPadLayer(const IConnectableLayer*, 126 const PadDescriptor&, 127 const char* name = nullptr) override; 128 129 void VisitPermuteLayer(const IConnectableLayer* layer, 130 const PermuteDescriptor& permuteDescriptor, 131 const char* name = nullptr) override; 132 133 void VisitPooling2dLayer(const IConnectableLayer* layer, 134 const Pooling2dDescriptor& pooling2dDescriptor, 135 const char* name = nullptr) override; 136 137 void VisitPreluLayer(const IConnectableLayer* layer, 138 const char* name = nullptr) override; 139 140 void VisitReshapeLayer(const IConnectableLayer* layer, 141 const ReshapeDescriptor& reshapeDescriptor, 142 const char* name = nullptr) override; 143 144 void VisitResizeLayer(const IConnectableLayer* layer, 145 const ResizeDescriptor& resizeDescriptor, 146 const char* name = nullptr) override; 147 148 ARMNN_DEPRECATED_MSG("Use VisitResizeLayer instead") 149 void VisitResizeBilinearLayer(const IConnectableLayer* layer, 150 const ResizeBilinearDescriptor& resizeDesc, 151 const char* name = nullptr) override; 152 153 ARMNN_DEPRECATED_MSG("Use VisitElementwiseUnaryLayer instead") 154 void VisitRsqrtLayer(const IConnectableLayer*, 155 const char* name = nullptr) override; 156 157 void VisitSliceLayer(const IConnectableLayer* layer, 158 const SliceDescriptor& sliceDescriptor, 159 const char* name = nullptr) override; 160 161 void VisitSoftmaxLayer(const IConnectableLayer* layer, 162 const SoftmaxDescriptor& softmaxDescriptor, 163 const char* name = nullptr) override; 164 165 void VisitSpaceToBatchNdLayer(const IConnectableLayer* layer, 166 const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor, 167 const char* name = nullptr) override; 168 169 void VisitSpaceToDepthLayer(const IConnectableLayer* layer, 170 const SpaceToDepthDescriptor& spaceToDepthDescriptor, 171 const char* name = nullptr) override; 172 173 void VisitSplitterLayer(const IConnectableLayer* layer, 174 const SplitterDescriptor& splitterDescriptor, 175 const char* name = nullptr) override; 176 177 void VisitStackLayer(const IConnectableLayer* layer, 178 const StackDescriptor& stackDescriptor, 179 const char* name = nullptr) override; 180 181 void VisitStridedSliceLayer(const IConnectableLayer* layer, 182 const StridedSliceDescriptor& stridedSliceDescriptor, 183 const char* name = nullptr) override; 184 185 void VisitSubtractionLayer(const IConnectableLayer* layer, 186 const char* name = nullptr) override; 187 188 void VisitTransposeConvolution2dLayer(const IConnectableLayer* layer, 189 const TransposeConvolution2dDescriptor& descriptor, 190 const ConstTensor& weights, 191 const Optional<ConstTensor>& biases, 192 const char* name = nullptr) override; 193 194 void VisitTransposeLayer(const IConnectableLayer* layer, 195 const TransposeDescriptor& descriptor, 196 const char* name = nullptr) override; 197 198 /// Extract the quantized network RetrieveFinalNetwork()199 INetworkPtr RetrieveFinalNetwork() { return std::move(m_QuantizedNetwork); } 200 201 private: 202 /// Connects the layer to preceeding layers and sets the quantization parameters based on recorded ranges 203 void SetQuantizedInputConnections(const IConnectableLayer* srcLayer, IConnectableLayer* quantizedLayer); 204 205 /// Record the guids so we can easily find the layers later 206 void RecordLayer(const IConnectableLayer* srcLayer, IConnectableLayer* qLayer); 207 208 /// Sets the bias quantization scale based on input and weight scales 209 ConstTensor CreateQuantizedBias(const IConnectableLayer* srcLayer, 210 const ConstTensor& weights, 211 const Optional<ConstTensor>& biases, 212 std::vector<int32_t>& weightsBacking); 213 214 /// Reference to the static range visitor used to retrieve the quantization ranges 215 const RangeTracker& m_Ranges; 216 217 /// Quantized version of the model we are building up 218 INetworkPtr m_QuantizedNetwork; 219 220 /// Mapping from input network guids to quantized network guids 221 std::unordered_map<LayerGuid, LayerGuid> m_OriginalToQuantizedGuidMap; 222 223 /// Mapping from guid to layer in quantized network 224 std::unordered_map<LayerGuid, IConnectableLayer*> m_QuantizedGuidToLayerMap; 225 226 const IQuantizationScheme* m_QuantizationScheme; 227 228 const bool m_PreserveType; 229 }; 230 231 } //namespace armnn 232