//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "Optimization.hpp"
#include <armnnUtils/DataLayoutIndexed.hpp>
#include <ResolveType.hpp>

namespace armnn
{
namespace optimizations
{

template <typename ConvLayer, armnn::DataType ArmnnType,
          typename T = armnn::ResolveType<ArmnnType>>
class FuseBatchNorm
{
public:
    /// Runs for every exclusive connection between a base Convolution (or DepthwiseConvolution) layer
    /// and a child BatchNorm layer, for non-quantized layers only.
    /// The child layer is removed, and the base layer is removed if it is left unconnected. A new
    /// Convolution layer is added; its weights and bias are calculated using the weights and bias of
    /// the base Convolution layer combined with the parameters of the child BatchNorm layer.
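    ///
    /// The fusion follows from applying the batch norm definition to the convolution output
    /// y = W*x + b:
    ///     BN(y) = gamma * (y - mean) / sqrt(variance + epsilon) + beta
    ///           = (gamma / sqrt(variance + epsilon)) * W * x
    ///             + (gamma * (b - mean)) / sqrt(variance + epsilon) + beta
    /// so the replacement layer uses:
    ///     fusedWeights = (gamma / sqrt(variance + epsilon)) * W
    ///     fusedBias    = (gamma * (b - mean)) / sqrt(variance + epsilon) + beta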
    void Run(Graph& graph, InputSlot& connection) const
    {
        Layer& base  = connection.GetConnectedOutputSlot()->GetOwningLayer();
        Layer& child = connection.GetOwningLayer();

        bool depthwise = (base.GetType() == LayerType::DepthwiseConvolution2d);

        ARMNN_ASSERT(base.GetType() == LayerType::Convolution2d || depthwise);
        ARMNN_ASSERT(child.GetType() == LayerType::BatchNormalization);

        if (base.GetDataType() == ArmnnType && child.GetDataType() == ArmnnType)
        {
            OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
            auto convLayer      = PolymorphicDowncast<ConvLayer*>(&base);
            auto batchNormLayer = PolymorphicDowncast<BatchNormalizationLayer*>(&child);

            // Read convolution and batch norm parameters
            BatchNormalizationDescriptor batchNormDescriptor = batchNormLayer->GetParameters();
            auto epsilon = batchNormDescriptor.m_Eps;

            ConstTensor betaTensor(batchNormLayer->m_Beta->GetTensorInfo(), batchNormLayer->m_Beta->Map(true));
            ConstTensor gammaTensor(batchNormLayer->m_Gamma->GetTensorInfo(), batchNormLayer->m_Gamma->Map(true));
            ConstTensor meanTensor(batchNormLayer->m_Mean->GetTensorInfo(), batchNormLayer->m_Mean->Map(true));
            ConstTensor varTensor(batchNormLayer->m_Variance->GetTensorInfo(), batchNormLayer->m_Variance->Map(true));

            auto convDescriptor = convLayer->GetParameters();
            auto weightsInfo(convLayer->m_Weight->GetTensorInfo());
            ConstTensor weightsTensor(weightsInfo, convLayer->m_Weight->Map(true));
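
            // Weight layouts handled below: a regular convolution stores weights as
            // [Cout, H, W, Cin] (NHWC) or [Cout, Cin, H, W] (NCHW), selected by the
            // descriptor's data layout; a depthwise convolution here stores them as
            // [M, Cin, H, W], where M is the depth multiplier and Cout = Cin * M.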
            armnnUtils::DataLayoutIndexed dataLayout(convDescriptor.m_DataLayout);
            auto weightsShape = weightsInfo.GetShape();
            const unsigned int depthMultiplier = depthwise ? weightsShape[0] : 1;
            const unsigned int inputChannels   = depthwise ? weightsShape[1] :
                                                             weightsShape[dataLayout.GetChannelsIndex()];
            const unsigned int outputChannels  = depthwise ? inputChannels * depthMultiplier : weightsShape[0];
            const unsigned int weightsHeight   = depthwise ? weightsShape[2] :
                                                             weightsShape[dataLayout.GetHeightIndex()];
            const unsigned int weightsWidth    = depthwise ? weightsShape[3] :
                                                             weightsShape[dataLayout.GetWidthIndex()];

            const auto* weightsBuffer = static_cast<const T*>(weightsTensor.GetMemoryArea());
            const auto* betaBuffer    = static_cast<const T*>(betaTensor.GetMemoryArea());
            const auto* gammaBuffer   = static_cast<const T*>(gammaTensor.GetMemoryArea());
            const auto* meanBuffer    = static_cast<const T*>(meanTensor.GetMemoryArea());
            const auto* varBuffer     = static_cast<const T*>(varTensor.GetMemoryArea());

            std::vector<T> weightsVector (weightsBuffer, weightsBuffer + weightsTensor.GetNumElements());
            std::vector<T> betaVector    (betaBuffer, betaBuffer + betaTensor.GetNumElements());
            std::vector<T> gammaVector   (gammaBuffer, gammaBuffer + gammaTensor.GetNumElements());
            std::vector<T> meanVector    (meanBuffer, meanBuffer + meanTensor.GetNumElements());
            std::vector<T> varianceVector(varBuffer, varBuffer + varTensor.GetNumElements());

            // fusedWeights = (gamma * weights) / sqrt(variance + epsilon)
            std::vector<T> fusedWeightsVector(weightsVector.size());
            unsigned int depthwiseMultiplierIdx = 0;

            for (unsigned int cInput = 0; cInput < inputChannels; ++cInput)
            {
                for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
                {
                    T mult = gammaVector[cOut] / static_cast<T>(sqrtf(varianceVector[cOut] + epsilon));

                    if (depthwise)
                    {
                        // For depthwise convolution the output channel determines both
                        // the input channel and the multiplier index.
                        cInput = cOut / depthMultiplier;
                        depthwiseMultiplierIdx = cOut % depthMultiplier;
                    }

                    for (unsigned int h = 0; h < weightsHeight; ++h)
                    {
                        for (unsigned int w = 0; w < weightsWidth; ++w)
                        {
                            unsigned int weightsIdx = 0;

                            if (depthwise)
                            {
                                weightsIdx = depthwiseMultiplierIdx * weightsWidth * weightsHeight * inputChannels +
                                             cInput * weightsWidth * weightsHeight +
                                             h * weightsWidth +
                                             w;
                            }
                            else if (convDescriptor.m_DataLayout == DataLayout::NHWC)
                            {
                                weightsIdx = cOut * weightsHeight * weightsWidth * inputChannels +
                                             h * weightsWidth * inputChannels +
                                             w * inputChannels +
                                             cInput;
                            }
                            else
                            {
                                weightsIdx = cOut * weightsWidth * weightsHeight * inputChannels +
                                             cInput * weightsWidth * weightsHeight +
                                             h * weightsWidth +
                                             w;
                            }
                            fusedWeightsVector[weightsIdx] = mult * weightsVector[weightsIdx];
                        }
                    }
                }
            }
            ConstTensor fusedWeightsTensor(weightsInfo, fusedWeightsVector);

            // fusedBias = (gamma * (bias - mean)) / sqrt(variance + epsilon) + beta
            // If the convolution has no bias, the same formula applies with bias = 0.
            std::vector<T> biasVector(outputChannels, T(0));
            if (convDescriptor.m_BiasEnabled)
            {
                ARMNN_ASSERT_MSG(convLayer->m_Bias != nullptr,
                                 "FuseBatchNorm: Bias data should not be null if bias is enabled.");

                ConstTensor biasTensor(convLayer->m_Bias->GetTensorInfo(), convLayer->m_Bias->Map(true));
                const auto* biasBuffer = static_cast<const T*>(biasTensor.GetMemoryArea());
                biasVector.assign(biasBuffer, biasBuffer + biasTensor.GetNumElements());
            }
            else
            {
                // The fused layer always carries a bias, so enable it on the new descriptor.
                convDescriptor.m_BiasEnabled = true;
            }

            std::vector<T> fusedBiasVector(outputChannels);
            for (unsigned int cOut = 0; cOut < outputChannels; ++cOut)
            {
                fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) /
                                         sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut];
            }
            ConstTensor fusedBiasTensor(TensorInfo({outputChannels}, ArmnnType), fusedBiasVector);

            // Insert the new convolution layer with the batch norm parameters fused in
            const std::string name = std::string("fused-") + child.GetName() + std::string("-into-") + base.GetName();
            auto& newConv2dLayer = *graph.InsertNewLayer<ConvLayer>(base.GetInputSlot(0),
                                                                    convDescriptor,
                                                                    name.c_str());
            newConv2dLayer.m_Weight = std::make_unique<ScopedCpuTensorHandle>(fusedWeightsTensor);
            newConv2dLayer.m_Bias   = std::make_unique<ScopedCpuTensorHandle>(fusedBiasTensor);

            // Reconnects with original parent.
            newConv2dLayer.GetOutputSlot().MoveAllConnections(*parentOut);
            // Parent is now the new convolution2d layer.
            parentOut = &newConv2dLayer.GetOutputSlot();

            // Moves connections in child output to parent layer.
            // Child layer will be removed as it's left unconnected.
            // Base layer will be removed if left unconnected.
            child.GetOutputSlot().MoveAllConnections(*parentOut);
        }
    }

protected:
    FuseBatchNorm()  = default;
    ~FuseBatchNorm() = default;
};

using FuseBatchNormIntoConvolution2DFloat32 =
        OptimizeForExclusiveConnection<Convolution2dLayer,
                                       BatchNormalizationLayer,
                                       FuseBatchNorm<Convolution2dLayer, armnn::DataType::Float32>>;

using FuseBatchNormIntoConvolution2DFloat16 =
        OptimizeForExclusiveConnection<Convolution2dLayer,
                                       BatchNormalizationLayer,
                                       FuseBatchNorm<Convolution2dLayer, armnn::DataType::Float16>>;

using FuseBatchNormIntoDepthwiseConvolution2DFloat32 =
        OptimizeForExclusiveConnection<DepthwiseConvolution2dLayer,
                                       BatchNormalizationLayer,
                                       FuseBatchNorm<DepthwiseConvolution2dLayer, armnn::DataType::Float32>>;

using FuseBatchNormIntoDepthwiseConvolution2DFloat16 =
        OptimizeForExclusiveConnection<DepthwiseConvolution2dLayer,
                                       BatchNormalizationLayer,
                                       FuseBatchNorm<DepthwiseConvolution2dLayer, armnn::DataType::Float16>>;
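
// A minimal usage sketch (assuming the Optimizer API declared in Optimizer.hpp): these
// aliases are applied to a Graph as part of an optimization pass, for example:
//
//     armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(
//         armnn::optimizations::FuseBatchNormIntoConvolution2DFloat32()));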

} // namespace optimizations
} // namespace armnn