• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "armnn/LayerVisitorBase.hpp"
9 #include "RangeTracker.hpp"
10 #include "layers/DebugLayer.hpp"
11 
12 #include <armnn/INetwork.hpp>
13 #include <armnnQuantizer/INetworkQuantizer.hpp>
14 
15 namespace armnn
16 {
17 
18 /// Visitor class to establish min/max ranges based on the type of the layer
19 class DynamicQuantizationVisitor : public LayerVisitorBase<VisitorThrowingPolicy>
20 {
21 public:
22     DynamicQuantizationVisitor(RangeTracker& rangeTracker, Graph& graph);
23     ~DynamicQuantizationVisitor() = default;
24 
25     /// Functions to set the Range on a per-layer-type basis
26     void VisitAbsLayer(const IConnectableLayer* layer,
27                        const char* name = nullptr) override;
28 
29     void VisitAdditionLayer(const IConnectableLayer* layer,
30                             const char* name = nullptr) override;
31 
32     void VisitArgMinMaxLayer(const IConnectableLayer* layer,
33                              const ArgMinMaxDescriptor& desc,
34                              const char* name = nullptr) override;
35 
36     void VisitNormalizationLayer(const IConnectableLayer* layer,
37                                  const NormalizationDescriptor& desc,
38                                  const char* name = nullptr) override ;
39 
40     void VisitBatchNormalizationLayer(const IConnectableLayer* layer,
41                                       const BatchNormalizationDescriptor& desc,
42                                       const ConstTensor& mean,
43                                       const ConstTensor& variance,
44                                       const ConstTensor& beta,
45                                       const ConstTensor& gamma,
46                                       const char* name = nullptr) override;
47 
48     void VisitConvolution2dLayer(const IConnectableLayer* layer,
49                                  const Convolution2dDescriptor& convolution2dDescriptor,
50                                  const ConstTensor& weights,
51                                  const Optional<ConstTensor>& biases,
52                                  const char* name = nullptr) override;
53 
54     void VisitDepthwiseConvolution2dLayer(const IConnectableLayer* layer,
55                                           const DepthwiseConvolution2dDescriptor& desc,
56                                           const ConstTensor& weights,
57                                           const Optional<ConstTensor>& biases,
58                                           const char* name = nullptr) override;
59 
60     void VisitActivationLayer(const IConnectableLayer* layer,
61                               const ActivationDescriptor& activationDescriptor,
62                               const char* name = nullptr) override;
63 
64     void VisitFullyConnectedLayer(const IConnectableLayer *layer,
65                                   const FullyConnectedDescriptor& desc,
66                                   const ConstTensor& weights,
67                                   const Optional<ConstTensor>& biases,
68                                   const char *name) override;
69 
70     void VisitPermuteLayer(const IConnectableLayer* layer,
71                            const PermuteDescriptor& permuteDescriptor,
72                            const char* name) override;
73 
74     void VisitSpaceToBatchNdLayer(const IConnectableLayer* layer,
75                                   const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
76                                   const char* name = nullptr) override;
77 
78     void VisitPooling2dLayer(const IConnectableLayer* layer,
79                              const Pooling2dDescriptor& pooling2dDescriptor,
80                              const char* name) override;
81 
82     void VisitSoftmaxLayer(const IConnectableLayer* layer,
83                            const SoftmaxDescriptor& softmaxDescriptor,
84                            const char* name = nullptr) override;
85 
86     void VisitConcatLayer(const IConnectableLayer* layer,
87                           const ConcatDescriptor& originsDescriptor,
88                           const char* name = nullptr) override;
89 
90     void VisitConstantLayer(const IConnectableLayer* layer,
91                             const ConstTensor& input,
92                             const char* name = nullptr) override;
93 
94     void VisitReshapeLayer(const IConnectableLayer* layer,
95                            const ReshapeDescriptor& reshapeDescriptor,
96                            const char* name = nullptr) override;
97 
98     void VisitSplitterLayer(const IConnectableLayer* layer,
99                             const SplitterDescriptor& splitterDescriptor,
100                             const char* name = nullptr) override;
101 
102     void VisitResizeBilinearLayer(const IConnectableLayer* layer,
103                                   const ResizeBilinearDescriptor& resizeDesc,
104                                   const char* name = nullptr) override;
105 
106     void VisitStridedSliceLayer(const IConnectableLayer* layer,
107                                 const StridedSliceDescriptor& stridedSliceDescriptor,
108                                 const char* name = nullptr) override;
109 
110     void VisitBatchToSpaceNdLayer(const IConnectableLayer* layer,
111                                   const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
112                                   const char* name = nullptr) override;
113 
114     void VisitInputLayer(const IConnectableLayer* layer,
115                          LayerBindingId id,
116                          const char* name = nullptr) override;
117 
118     void VisitOutputLayer(const IConnectableLayer* layer,
119                           LayerBindingId id,
120                           const char* name = nullptr) override;
121 
122     void FinishVisit() override;
123     void VisitNonCalibratedLayers();
124 
125     const std::vector<armnn::LayerBindingId>& GetOutputLayers();
126 
127 private:
128     /// Set the range for an output slot on a layer
129     void SetRange(const IConnectableLayer* layer, unsigned int outputIdx, float min, float max);
130 
131     void ForwardParentParameters(const IConnectableLayer* layer);
132 
133     /// Mapping from a layer Guid to an array of ranges for outputs
134     RangeTracker& m_RangeTracker;
135 
136     Graph& m_Graph;
137 
138     std::vector<const IConnectableLayer*> m_LayersToCalibrate;
139     std::vector<const IConnectableLayer*> m_LayersNotToCalibrate;
140     std::vector<DebugLayer*> m_DebugLayers;
141 
142     std::vector<armnn::LayerBindingId> m_OutputLayers;
143 
144     void AddToCalibratedLayers(const IConnectableLayer* layer);
145     void AddToNonCalibratedLayers(const IConnectableLayer* layer);
146     void RemoveDebugLayers();
147 };
148 
149 } //namespace armnn
150