• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "armnn/LayerVisitorBase.hpp"
9 #include "StaticRangeVisitor.hpp"
10 #include "NetworkQuantizationScheme.hpp"
11 
12 #include <armnn/INetwork.hpp>
13 #include <armnn/Types.hpp>
14 #include <armnnQuantizer/INetworkQuantizer.hpp>
15 
16 #include <unordered_map>
17 
18 namespace armnn
19 {
20 
21 // Forward declaration
22 class StaticRangeVisitor;
23 
24 /// Visitor object for quantizing layers in a network
25 class QuantizerVisitor : public LayerVisitorBase<VisitorThrowingPolicy>
26 {
27 public:
28     QuantizerVisitor(const RangeTracker& rangeTracker,
29                      const IQuantizationScheme* quantizationScheme,
30                      bool preserveType = false);
31 
32     ~QuantizerVisitor() = default;
33 
34     /// Functions to quantize the individual layers, overridden from ILayerVisitor
35     ARMNN_DEPRECATED_MSG("Use VisitElementwiseUnaryLayer instead")
36     void VisitAbsLayer(const IConnectableLayer* layer, const char* name = nullptr) override;
37 
38     void VisitActivationLayer(const IConnectableLayer* layer,
39                               const ActivationDescriptor& activationDescriptor,
40                               const char* name = nullptr) override;
41 
42     void VisitAdditionLayer(const IConnectableLayer* layer, const char* name = nullptr) override;
43 
44     void VisitArgMinMaxLayer(const IConnectableLayer* layer,
45                              const ArgMinMaxDescriptor& argMinMaxDescriptor,
46                              const char* name = nullptr) override;
47 
48     void VisitBatchNormalizationLayer(const IConnectableLayer* layer,
49                                       const BatchNormalizationDescriptor& desc,
50                                       const ConstTensor& mean,
51                                       const ConstTensor& variance,
52                                       const ConstTensor& beta,
53                                       const ConstTensor& gamma,
54                                       const char* name = nullptr) override;
55 
56     void VisitBatchToSpaceNdLayer(const IConnectableLayer* layer,
57                                   const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
58                                   const char* name = nullptr) override;
59 
60     void VisitComparisonLayer(const IConnectableLayer* layer,
61                               const ComparisonDescriptor& comparisonDescriptor,
62                               const char* name = nullptr) override;
63 
64     void VisitConcatLayer(const IConnectableLayer* layer,
65                           const OriginsDescriptor& originsDescriptor,
66                           const char* name = nullptr) override;
67 
68     void VisitConstantLayer(const IConnectableLayer* layer,
69                             const ConstTensor& input,
70                             const char* name = nullptr) override;
71 
72     void VisitConvolution2dLayer(const IConnectableLayer* layer,
73                                  const Convolution2dDescriptor& convolution2dDescriptor,
74                                  const ConstTensor& weights,
75                                  const Optional<ConstTensor>& biases,
76                                  const char* name = nullptr) override;
77 
78     void VisitDepthToSpaceLayer(const IConnectableLayer* layer,
79                                 const DepthToSpaceDescriptor& depthToSpaceDescriptor,
80                                 const char* name = nullptr) override;
81 
82     void VisitDepthwiseConvolution2dLayer(const IConnectableLayer* layer,
83                                           const DepthwiseConvolution2dDescriptor& desc,
84                                           const ConstTensor& weights,
85                                           const Optional<ConstTensor>& biases,
86                                           const char* name = nullptr) override;
87 
88     void VisitElementwiseUnaryLayer(const IConnectableLayer* layer,
89                                     const ElementwiseUnaryDescriptor& elementwiseUnaryDescriptor,
90                                     const char* name = nullptr) override;
91 
92     void VisitFillLayer(const IConnectableLayer* layer,
93                         const FillDescriptor& desc,
94                         const char* name) override;
95 
96     void VisitFullyConnectedLayer(const IConnectableLayer *layer,
97                                   const FullyConnectedDescriptor& desc,
98                                   const ConstTensor& weights,
99                                   const Optional<ConstTensor>& biases,
100                                   const char *name = nullptr)  override;
101 
102     void VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name = nullptr) override;
103 
104     void VisitInstanceNormalizationLayer(const IConnectableLayer* layer,
105                                          const InstanceNormalizationDescriptor& instanceNormalizationDescriptor,
106                                          const char* name = nullptr) override;
107 
108     void VisitLogSoftmaxLayer(const IConnectableLayer* layer,
109                               const LogSoftmaxDescriptor& logSoftmaxDescriptor,
110                               const char* name = nullptr) override;
111 
112     void VisitMeanLayer(const IConnectableLayer* layer,
113                         const MeanDescriptor& meanDescriptor,
114                         const char* name = nullptr) override;
115 
116     void VisitMultiplicationLayer(const IConnectableLayer* layer,
117                                   const char* name = nullptr) override;
118 
119     void VisitNormalizationLayer(const IConnectableLayer* layer,
120                                  const NormalizationDescriptor& normalizationDescriptor,
121                                  const char* name = nullptr) override;
122 
123     void VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name = nullptr)  override;
124 
125     void VisitPadLayer(const IConnectableLayer*,
126                        const PadDescriptor&,
127                        const char* name = nullptr) override;
128 
129     void VisitPermuteLayer(const IConnectableLayer* layer,
130                            const PermuteDescriptor& permuteDescriptor,
131                            const char* name = nullptr) override;
132 
133     void VisitPooling2dLayer(const IConnectableLayer* layer,
134                              const Pooling2dDescriptor& pooling2dDescriptor,
135                              const char* name = nullptr) override;
136 
137     void VisitPreluLayer(const IConnectableLayer* layer,
138                          const char* name = nullptr) override;
139 
140     void VisitReshapeLayer(const IConnectableLayer* layer,
141                            const ReshapeDescriptor& reshapeDescriptor,
142                            const char* name = nullptr) override;
143 
144     void VisitResizeLayer(const IConnectableLayer* layer,
145                           const ResizeDescriptor& resizeDescriptor,
146                           const char* name = nullptr) override;
147 
148     ARMNN_DEPRECATED_MSG("Use VisitResizeLayer instead")
149     void VisitResizeBilinearLayer(const IConnectableLayer* layer,
150                                   const ResizeBilinearDescriptor& resizeDesc,
151                                   const char* name = nullptr) override;
152 
153     ARMNN_DEPRECATED_MSG("Use VisitElementwiseUnaryLayer instead")
154     void VisitRsqrtLayer(const IConnectableLayer*,
155                          const char* name = nullptr) override;
156 
157     void VisitSliceLayer(const IConnectableLayer* layer,
158                          const SliceDescriptor& sliceDescriptor,
159                          const char* name = nullptr) override;
160 
161     void VisitSoftmaxLayer(const IConnectableLayer* layer,
162                            const SoftmaxDescriptor& softmaxDescriptor,
163                            const char* name = nullptr) override;
164 
165     void VisitSpaceToBatchNdLayer(const IConnectableLayer* layer,
166                                   const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
167                                   const char* name = nullptr) override;
168 
169     void VisitSpaceToDepthLayer(const IConnectableLayer* layer,
170                                 const SpaceToDepthDescriptor& spaceToDepthDescriptor,
171                                 const char* name = nullptr) override;
172 
173     void VisitSplitterLayer(const IConnectableLayer* layer,
174                             const SplitterDescriptor& splitterDescriptor,
175                             const char* name = nullptr) override;
176 
177     void VisitStackLayer(const IConnectableLayer* layer,
178                          const StackDescriptor& stackDescriptor,
179                          const char* name = nullptr) override;
180 
181     void VisitStridedSliceLayer(const IConnectableLayer* layer,
182                                 const StridedSliceDescriptor& stridedSliceDescriptor,
183                                 const char* name = nullptr) override;
184 
185     void VisitSubtractionLayer(const IConnectableLayer* layer,
186                                const char* name = nullptr) override;
187 
188     void VisitTransposeConvolution2dLayer(const IConnectableLayer* layer,
189                                           const TransposeConvolution2dDescriptor& descriptor,
190                                           const ConstTensor& weights,
191                                           const Optional<ConstTensor>& biases,
192                                           const char* name = nullptr) override;
193 
194     void VisitTransposeLayer(const IConnectableLayer* layer,
195                              const TransposeDescriptor& descriptor,
196                              const char* name = nullptr) override;
197 
198     /// Extract the quantized network
RetrieveFinalNetwork()199     INetworkPtr RetrieveFinalNetwork() { return std::move(m_QuantizedNetwork); }
200 
201 private:
202     /// Connects the layer to preceeding layers and sets the quantization parameters based on recorded ranges
203     void SetQuantizedInputConnections(const IConnectableLayer* srcLayer, IConnectableLayer* quantizedLayer);
204 
205     /// Record the guids so we can easily find the layers later
206     void RecordLayer(const IConnectableLayer* srcLayer, IConnectableLayer* qLayer);
207 
208     /// Sets the bias quantization scale based on input and weight scales
209     ConstTensor CreateQuantizedBias(const IConnectableLayer* srcLayer,
210                                     const ConstTensor& weights,
211                                     const Optional<ConstTensor>& biases,
212                                     std::vector<int32_t>& weightsBacking);
213 
214     /// Reference to the static range visitor used to retrieve the quantization ranges
215     const RangeTracker& m_Ranges;
216 
217     /// Quantized version of the model we are building up
218     INetworkPtr m_QuantizedNetwork;
219 
220     /// Mapping from input network guids to quantized network guids
221     std::unordered_map<LayerGuid, LayerGuid> m_OriginalToQuantizedGuidMap;
222 
223     /// Mapping from guid to layer in quantized network
224     std::unordered_map<LayerGuid, IConnectableLayer*> m_QuantizedGuidToLayerMap;
225 
226     const IQuantizationScheme* m_QuantizationScheme;
227 
228     const bool m_PreserveType;
229 };
230 
231 } //namespace armnn
232