//
// Copyright © 2020,2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "Optimization.hpp"
#include <armnnUtils/DataLayoutIndexed.hpp>
#include <ResolveType.hpp>

namespace armnn
{
namespace optimizations
{

template<typename ConvLayer, armnn::DataType ArmnnType,
         typename T = armnn::ResolveType<ArmnnType>>
class FuseBatchNorm
{
public:
    /// Run for every exclusive connection between any base Convolution layer and a child BatchNorm layer,
    /// for non-quantized layers.
    /// The child will be removed, and the base will be removed if it is left unconnected. A new Convolution
    /// layer will be added; its weights and bias are calculated from the weights and bias of the base
    /// Convolution layer combined with the parameters of the child BatchNorm layer.
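    ///
    /// Batch normalization computes y = gamma * (x - mean) / sqrt(variance + epsilon) + beta per
    /// output channel, so folding it into the preceding convolution amounts to:
    ///     fusedWeights = (gamma / sqrt(variance + epsilon)) * weights
    ///     fusedBias    = (gamma * (bias - mean)) / sqrt(variance + epsilon) + beta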
    void Run(Graph& graph, InputSlot& connection) const
    {
        Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
        Layer& child = connection.GetOwningLayer();

        bool depthwise = (base.GetType() == LayerType::DepthwiseConvolution2d);

        ARMNN_ASSERT(base.GetType() == LayerType::Convolution2d || depthwise);
        ARMNN_ASSERT(child.GetType() == LayerType::BatchNormalization);

        if (base.GetDataType() == ArmnnType && child.GetDataType() == ArmnnType)
        {
            OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
            auto convLayer = PolymorphicDowncast<ConvLayer*>(&base);
            auto batchNormLayer = PolymorphicDowncast<BatchNormalizationLayer*>(&child);

            // Read convolution and batch norm parameters
            BatchNormalizationDescriptor batchNormDescriptor = batchNormLayer->GetParameters();
            auto epsilon = batchNormDescriptor.m_Eps;
            IgnoreUnused(epsilon);

            ConstTensor betaTensor(batchNormLayer->m_Beta->GetTensorInfo(), batchNormLayer->m_Beta->Map(true));
            ConstTensor gammaTensor(batchNormLayer->m_Gamma->GetTensorInfo(), batchNormLayer->m_Gamma->Map(true));
            ConstTensor meanTensor(batchNormLayer->m_Mean->GetTensorInfo(), batchNormLayer->m_Mean->Map(true));
            ConstTensor varTensor(batchNormLayer->m_Variance->GetTensorInfo(), batchNormLayer->m_Variance->Map(true));

            auto convDescriptor = convLayer->GetParameters();
            ConstTensor weightsTensor;
            ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[1].GetConnection() != nullptr,
                             "FuseBatchNorm: Weight data should not be null.");

            ConstantLayer* weightLayer = PolymorphicDowncast<ConstantLayer*>(
                                        &base.GetInputSlot(1).GetConnectedOutputSlot()->GetOwningLayer());

            weightsTensor = ConstTensor(weightLayer->m_LayerOutput->GetTensorInfo(),
                                        weightLayer->m_LayerOutput->Map(true));

            armnnUtils::DataLayoutIndexed dataLayout(convDescriptor.m_DataLayout);
            auto weightsShape = weightsTensor.GetInfo().GetShape();
            const unsigned int inputChannels   = parentOut->GetTensorInfo().GetShape()[dataLayout.GetChannelsIndex()];
            const unsigned int depthMultiplier = depthwise ? weightsShape[3] / inputChannels : 1;
            const unsigned int outputChannels  = depthwise ? weightsShape[3] : weightsShape[0];
            const unsigned int weightsHeight   = depthwise ? weightsShape[1] :
                                                 weightsShape[dataLayout.GetHeightIndex()];
            const unsigned int weightsWidth    = depthwise ? weightsShape[2] :
                                                 weightsShape[dataLayout.GetWidthIndex()];
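            // Regular convolution weights follow m_DataLayout, while depthwise weights use the
            // [1, H, W, Cin * depthMultiplier] layout, hence the fixed shape indices above.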

            const auto* weightsBuffer = static_cast<const T*>(weightsTensor.GetMemoryArea());
            const auto* betaBuffer    = static_cast<const T*>(betaTensor.GetMemoryArea());
            const auto* gammaBuffer   = static_cast<const T*>(gammaTensor.GetMemoryArea());
            const auto* meanBuffer    = static_cast<const T*>(meanTensor.GetMemoryArea());
            const auto* varBuffer     = static_cast<const T*>(varTensor.GetMemoryArea());

            std::vector<T> weightsVector (weightsBuffer, weightsBuffer + weightsTensor.GetNumElements());
            std::vector<T> betaVector    (betaBuffer, betaBuffer + betaTensor.GetNumElements());
            std::vector<T> gammaVector   (gammaBuffer, gammaBuffer + gammaTensor.GetNumElements());
            std::vector<T> meanVector    (meanBuffer, meanBuffer + meanTensor.GetNumElements());
            std::vector<T> varianceVector(varBuffer, varBuffer + varTensor.GetNumElements());

            // fusedWeights = (gamma * weights) / sqrt(variance + epsilon)
            std::vector<T> fusedWeightsVector(weightsVector.size());

            for (unsigned int cInput = 0; cInput < inputChannels; ++cInput)
            {
                for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
                {
                    T mult = gammaVector[cOut] / static_cast<T>(sqrtf(varianceVector[cOut] + epsilon));

                    for (unsigned int h = 0; h < weightsHeight; ++h)
                    {
                        for (unsigned int w = 0; w < weightsWidth; ++w)
                        {
                            unsigned int weightsIdx = 0;

                            if (depthwise)
                            {
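                                // Each depthwise output channel cOut reads from input channel
                                // cOut / depthMultiplier; the [1, H, W, Cout] layout means the
                                // weights index strides by outputChannels along W and H.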
                                cInput = cOut / depthMultiplier;
                                weightsIdx = w * outputChannels + cOut +
                                             h * weightsWidth * outputChannels;
                            }
                            else if (convDescriptor.m_DataLayout == DataLayout::NHWC)
                            {
                                weightsIdx = cOut * weightsHeight * weightsWidth * inputChannels +
                                             h * weightsWidth * inputChannels +
                                             w * inputChannels +
                                             cInput;
                            }
                            else
                            {
                                weightsIdx = cOut * weightsWidth * weightsHeight * inputChannels +
                                             cInput * weightsWidth * weightsHeight +
                                             h * weightsWidth +
                                             w;
                            }
                            fusedWeightsVector[weightsIdx] = mult * weightsVector[weightsIdx];
                        }
                    }
                }
            }
            ConstTensor fusedWeightsTensor(weightsTensor.GetInfo(), fusedWeightsVector);

            // fusedBias = (gamma * (bias - mean)) / sqrt(variance + epsilon) + beta
            std::vector<T> fusedBiasVector(outputChannels);
            bool biasWasEnabledBeforeOpt = convDescriptor.m_BiasEnabled;
            if (biasWasEnabledBeforeOpt)
            {
                ConstTensor biasTensor;
                ARMNN_ASSERT_MSG(convLayer->GetInputSlots()[2].GetConnection() != nullptr,
                                 "FuseBatchNorm: Bias data should not be null if bias is enabled.");

                ConstantLayer* biasLayer = PolymorphicDowncast<ConstantLayer*>(
                                                &base.GetInputSlot(2).GetConnectedOutputSlot()->GetOwningLayer());

                biasTensor = ConstTensor(biasLayer->m_LayerOutput->GetTensorInfo(),
                                         biasLayer->m_LayerOutput->Map(true));

                const auto* biasBuffer = static_cast<const T*>(biasTensor.GetMemoryArea());
                std::vector<T> biasVector(biasBuffer, biasBuffer + biasTensor.GetNumElements());

                for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
                {
                    fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
                                             sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
                }
            }
            else
            {
                convDescriptor.m_BiasEnabled = true;
                std::vector<T> biasVector(outputChannels, T(0));

                for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
                {
                    fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
                                             sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
                }
            }
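            // The trailing 'true' below flags the TensorInfo as constant data.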
            ConstTensor fusedBiasTensor(TensorInfo({outputChannels}, ArmnnType, 0.0f, 0, true), fusedBiasVector);

            // Insert the new convolution layer that has the batch norm parameters fused into it
            const std::string name = std::string("fused-") + child.GetName() + std::string("-into-") + base.GetName();
            auto& newConv2dLayer = *graph.InsertNewLayer<ConvLayer>(base.GetInputSlot(0),
                                                                    convDescriptor,
                                                                    name.c_str());

            // Connect the weights and bias from the old to the new Conv2d layer.
            // After this optimization the Conv2d layer will always have 3 input slots
            // (input, weights and bias).
            if (newConv2dLayer.GetNumInputSlots() > 1)
            {
                // Remove the old connection and connect to the new Conv2d layer
                weightLayer->GetOutputSlot(0).Disconnect(base.GetInputSlot(1));
                weightLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(1));
                weightLayer->m_LayerOutput = std::make_unique<ScopedTensorHandle>(fusedWeightsTensor);

                // Move the bias constant layer across as normal if bias was enabled before the optimisation
                ConstantLayer* biasLayer;
                if (biasWasEnabledBeforeOpt)
                {
                    biasLayer = PolymorphicDowncast<ConstantLayer*>(
                        &base.GetInputSlot(2).GetConnectedOutputSlot()->GetOwningLayer());
                    // Remove the old connection and connect to the new Conv2d layer
                    biasLayer->GetOutputSlot(0).Disconnect(base.GetInputSlot(2));
                    biasLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(2));
                }
                // Otherwise create a new bias layer and add it to the new Conv2d layer
                else
                {
                    // Add in a bias constant layer
                    biasLayer = graph.AddLayer<ConstantLayer>("Bias");
                    biasLayer->GetOutputSlot(0).SetTensorInfo(fusedBiasTensor.GetInfo());
                    biasLayer->GetOutputSlot(0).Connect(newConv2dLayer.GetInputSlot(2));
                }
                biasLayer->m_LayerOutput = std::make_unique<ScopedTensorHandle>(ConstTensor(fusedBiasTensor));
            }

            // Reconnects with original parent.
            newConv2dLayer.GetOutputSlot().MoveAllConnections(*parentOut);
            // Parent is now the new convolution2d layer.
            parentOut = &newConv2dLayer.GetOutputSlot();

            // Moves connections in child output to parent layer.
            // Child layer will be removed as it's left unconnected.
            // Base layer will be removed if left unconnected.
            child.GetOutputSlot().MoveAllConnections(*parentOut);
        }
    }
protected:
    FuseBatchNorm()  = default;
    ~FuseBatchNorm() = default;
};

using FuseBatchNormIntoConvolution2DFloat32 =
        OptimizeForExclusiveConnection<Convolution2dLayer,
                                       BatchNormalizationLayer,
                                       FuseBatchNorm<Convolution2dLayer, armnn::DataType::Float32>>;

using FuseBatchNormIntoConvolution2DFloat16 =
        OptimizeForExclusiveConnection<Convolution2dLayer,
                                       BatchNormalizationLayer,
                                       FuseBatchNorm<Convolution2dLayer, armnn::DataType::Float16>>;

using FuseBatchNormIntoDepthwiseConvolution2DFloat32 =
        OptimizeForExclusiveConnection<DepthwiseConvolution2dLayer,
                                       BatchNormalizationLayer,
                                       FuseBatchNorm<DepthwiseConvolution2dLayer, armnn::DataType::Float32>>;

using FuseBatchNormIntoDepthwiseConvolution2DFloat16 =
        OptimizeForExclusiveConnection<DepthwiseConvolution2dLayer,
                                       BatchNormalizationLayer,
                                       FuseBatchNorm<DepthwiseConvolution2dLayer, armnn::DataType::Float16>>;

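// A minimal usage sketch, assuming the Optimizer and MakeOptimizations helpers from
// Optimizer.hpp; these aliases are applied to a Graph as an optimization pass, e.g.:
//
//     Optimizer::Pass(graph, MakeOptimizations(FuseBatchNormIntoConvolution2DFloat32()));
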
} // namespace optimizations
} // namespace armnn