1 // 2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "Optimization.hpp" 9 #include <armnnUtils/DataLayoutIndexed.hpp> 10 #include <ResolveType.hpp> 11 12 namespace armnn 13 { 14 namespace optimizations 15 { 16 17 template <typename ConvLayer, armnn::DataType ArmnnType, 18 typename T = armnn::ResolveType<ArmnnType>> 19 class FuseBatchNorm 20 { 21 public: 22 /// Run for every exclusive connection between any base Convolution layer and a child BatchNorm layer for not 23 /// quantized layers. 24 /// The child will be removed, the base will be removed if it's left unconnected. A new Convolution layer will 25 /// be added, its weights and bias will be calculated using the weights and bias of the base Convolution layer 26 /// combined with the parameters of the child BatchNorm layer. Run(Graph & graph,InputSlot & connection) const27 void Run(Graph& graph, InputSlot& connection) const 28 { 29 Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); 30 Layer& child = connection.GetOwningLayer(); 31 32 bool depthwise = (base.GetType() == LayerType::DepthwiseConvolution2d); 33 34 ARMNN_ASSERT(base.GetType() == LayerType::Convolution2d || depthwise); 35 ARMNN_ASSERT(child.GetType() == LayerType::BatchNormalization); 36 37 if (base.GetDataType() == ArmnnType && child.GetDataType() == ArmnnType) 38 { 39 OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot(); 40 auto convLayer = PolymorphicDowncast<ConvLayer*>(&base); 41 auto batchNormLayer = PolymorphicDowncast<BatchNormalizationLayer*>(&child); 42 43 // Read convolution and batch norm parameters 44 BatchNormalizationDescriptor batchNormDescriptor = batchNormLayer->GetParameters(); 45 auto epsilon = batchNormDescriptor.m_Eps; 46 IgnoreUnused(epsilon); 47 48 ConstTensor betaTensor(batchNormLayer->m_Beta->GetTensorInfo(), batchNormLayer->m_Beta->Map(true)); 49 ConstTensor gammaTensor(batchNormLayer->m_Gamma->GetTensorInfo(), batchNormLayer->m_Gamma->Map(true)); 50 ConstTensor meanTensor(batchNormLayer->m_Mean->GetTensorInfo(), batchNormLayer->m_Mean->Map(true)); 51 ConstTensor varTensor(batchNormLayer->m_Variance->GetTensorInfo(), batchNormLayer->m_Variance->Map(true)); 52 53 auto convDescriptor = convLayer->GetParameters(); 54 auto weightsInfo(convLayer->m_Weight->GetTensorInfo()); 55 ConstTensor weightsTensor(weightsInfo, convLayer->m_Weight->Map(true)); 56 57 armnnUtils::DataLayoutIndexed dataLayout(convDescriptor.m_DataLayout); 58 auto weightsShape = weightsInfo.GetShape(); 59 const unsigned int depthMultiplier = depthwise ? weightsShape[0] : 1; 60 const unsigned int inputChannels = depthwise ? weightsShape[1] : 61 weightsShape[dataLayout.GetChannelsIndex()]; 62 const unsigned int outputChannels = depthwise ? inputChannels * depthMultiplier : weightsShape[0]; 63 const unsigned int weightsHeight = depthwise ? weightsShape[2] : 64 weightsShape[dataLayout.GetHeightIndex()]; 65 const unsigned int weightsWidth = depthwise ? weightsShape[3] : 66 weightsShape[dataLayout.GetWidthIndex()]; 67 68 const auto* weightsBuffer = static_cast<const T*>(weightsTensor.GetMemoryArea()); 69 const auto* betaBuffer = static_cast<const T*>(betaTensor.GetMemoryArea()); 70 const auto* gammaBuffer = static_cast<const T*>(gammaTensor.GetMemoryArea()); 71 const auto* meanBuffer = static_cast<const T*>(meanTensor.GetMemoryArea()); 72 const auto* varBuffer = static_cast<const T*>(varTensor.GetMemoryArea()); 73 74 std::vector<T> weightsVector (weightsBuffer, weightsBuffer + weightsTensor.GetNumElements()); 75 std::vector<T> betaVector (betaBuffer, betaBuffer + betaTensor.GetNumElements()); 76 std::vector<T> gammaVector (gammaBuffer, gammaBuffer + gammaTensor.GetNumElements()); 77 std::vector<T> meanVector (meanBuffer, meanBuffer + meanTensor.GetNumElements()); 78 std::vector<T> varianceVector(varBuffer, varBuffer + varTensor.GetNumElements()); 79 80 // fusedWeights = ( gamma * weights ) / ( std - epsilon); 81 std::vector<T> fusedWeightsVector(weightsVector.size()); 82 unsigned int depthwiseMultiplierIdx = 0; 83 84 for (unsigned int cInput = 0; cInput < inputChannels; ++cInput) 85 { 86 for (unsigned int cOut = 0; cOut < outputChannels; ++cOut) 87 { 88 T mult = gammaVector[cOut] / static_cast<T>(sqrtf (varianceVector[cOut] + epsilon)); 89 90 if (depthwise) 91 { 92 cInput = cOut / depthMultiplier; 93 depthwiseMultiplierIdx = cOut % depthMultiplier; 94 } 95 96 for (unsigned int h = 0; h < weightsHeight; ++h) 97 { 98 for (unsigned int w = 0; w < weightsWidth; ++w) 99 { 100 unsigned int weightsIdx = 0; 101 102 if (depthwise) 103 { 104 weightsIdx = depthwiseMultiplierIdx * weightsWidth * weightsHeight * inputChannels + 105 cInput * weightsWidth * weightsHeight + 106 h * weightsWidth + 107 w; 108 } 109 else if (convDescriptor.m_DataLayout == DataLayout::NHWC) 110 { 111 weightsIdx = cOut * weightsHeight * weightsWidth * inputChannels + 112 h * weightsWidth * inputChannels + 113 w * inputChannels + 114 cInput; 115 } 116 else 117 { 118 weightsIdx = cOut * weightsWidth * weightsHeight * inputChannels + 119 cInput * weightsWidth * weightsHeight + 120 h * weightsWidth + 121 w; 122 } 123 fusedWeightsVector[weightsIdx] = mult * weightsVector[weightsIdx]; 124 } 125 } 126 } 127 } 128 ConstTensor fusedWeightsTensor(weightsInfo, fusedWeightsVector); 129 130 // fusedBias = (gamma * (bias - mean)) / (variance - epsilon) + beta; 131 std::vector<T> fusedBiasVector(outputChannels); 132 if (convDescriptor.m_BiasEnabled) 133 { 134 ARMNN_ASSERT_MSG(convLayer->m_Bias != nullptr, 135 "FuseBatchNorm: Bias data should not be null if bias is enabled."); 136 137 ConstTensor biasTensor(convLayer->m_Bias->GetTensorInfo(), convLayer->m_Bias->Map(true)); 138 const auto* biasBuffer = static_cast<const T*>(biasTensor.GetMemoryArea()); 139 std::vector<T> biasVector(biasBuffer, biasBuffer + biasTensor.GetNumElements()); 140 141 for (unsigned int cOut = 0; cOut < outputChannels; ++cOut) 142 { 143 fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) / 144 sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut]; 145 } 146 } 147 else 148 { 149 convDescriptor.m_BiasEnabled = true; 150 std::vector<T> biasVector(outputChannels, T(0)); 151 152 for (unsigned int cOut = 0; cOut < outputChannels; ++cOut) 153 { 154 fusedBiasVector[cOut] = ((gammaVector[cOut] * (biasVector[cOut] - meanVector[cOut])) / 155 sqrtf(varianceVector[cOut] + epsilon)) + betaVector[cOut]; 156 } 157 } 158 ConstTensor fusedBiasTensor(TensorInfo({outputChannels}, ArmnnType), fusedBiasVector); 159 160 // Insert the new convolution layer that has batch norm parameters fused into 161 const std::string name = std::string("fused-") + child.GetName() + std::string("-into-") + base.GetName(); 162 auto& newConv2dLayer = *graph.InsertNewLayer<ConvLayer>(base.GetInputSlot(0), 163 convDescriptor, 164 name.c_str()); 165 newConv2dLayer.m_Weight = std::make_unique<ScopedCpuTensorHandle>(fusedWeightsTensor); 166 newConv2dLayer.m_Bias = std::make_unique<ScopedCpuTensorHandle>(ConstTensor(fusedBiasTensor)); 167 168 // Reconnects with original parent. 169 newConv2dLayer.GetOutputSlot().MoveAllConnections(*parentOut); 170 // Parent is now the new convolution2d layer. 171 parentOut = &newConv2dLayer.GetOutputSlot(); 172 173 // Moves connections in child output to parent layer. 174 // Child layer will be removed as it's left unconnected. 175 // Base layer will be removed if left unconnected. 176 child.GetOutputSlot().MoveAllConnections(*parentOut); 177 } 178 } 179 protected: 180 FuseBatchNorm() = default; 181 ~FuseBatchNorm() = default; 182 }; 183 184 using FuseBatchNormIntoConvolution2DFloat32 = 185 OptimizeForExclusiveConnection<Convolution2dLayer, 186 BatchNormalizationLayer, 187 FuseBatchNorm<Convolution2dLayer, armnn::DataType::Float32>>; 188 189 using FuseBatchNormIntoConvolution2DFloat16 = 190 OptimizeForExclusiveConnection<Convolution2dLayer, 191 BatchNormalizationLayer, 192 FuseBatchNorm<Convolution2dLayer, armnn::DataType::Float16>>; 193 194 using FuseBatchNormIntoDepthwiseConvolution2DFloat32 = 195 OptimizeForExclusiveConnection<DepthwiseConvolution2dLayer, 196 BatchNormalizationLayer, 197 FuseBatchNorm<DepthwiseConvolution2dLayer, armnn::DataType::Float32>>; 198 199 using FuseBatchNormIntoDepthwiseConvolution2DFloat16 = 200 OptimizeForExclusiveConnection<DepthwiseConvolution2dLayer, 201 BatchNormalizationLayer, 202 FuseBatchNorm<DepthwiseConvolution2dLayer, armnn::DataType::Float16>>; 203 204 } // namespace optimizations 205 } // namespace armnn