//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ConvImpl.hpp"

#include <armnn/utility/Assert.hpp>

#include <cmath>
#include <limits>

namespace armnn
{

QuantizedMultiplierSmallerThanOne::QuantizedMultiplierSmallerThanOne(float multiplier)
{
    ARMNN_ASSERT(multiplier >= 0.0f && multiplier < 1.0f);
    if (multiplier == 0.0f)
    {
        m_Multiplier = 0;
        m_RightShift = 0;
    }
    else
    {
        const double q = std::frexp(multiplier, &m_RightShift);
        m_RightShift = -m_RightShift;
        int64_t qFixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
        ARMNN_ASSERT(qFixed <= (1ll << 31));
        if (qFixed == (1ll << 31))
        {
            qFixed /= 2;
            --m_RightShift;
        }
        ARMNN_ASSERT(m_RightShift >= 0);
        ARMNN_ASSERT(qFixed <= std::numeric_limits<int32_t>::max());
        m_Multiplier = static_cast<int32_t>(qFixed);
    }
}

int32_t QuantizedMultiplierSmallerThanOne::operator*(int32_t rhs) const
{
    int32_t x = SaturatingRoundingDoublingHighMul(rhs, m_Multiplier);
    return RoundingDivideByPOT(x, m_RightShift);
}

int32_t QuantizedMultiplierSmallerThanOne::SaturatingRoundingDoublingHighMul(int32_t a, int32_t b)
{
    // Check for overflow.
    if (a == b && a == std::numeric_limits<int32_t>::min())
    {
        return std::numeric_limits<int32_t>::max();
    }
    int64_t a_64(a);
    int64_t b_64(b);
    int64_t ab_64 = a_64 * b_64;
    int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
    int32_t ab_x2_high32 = static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
    return ab_x2_high32;
}
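
// Informally, SaturatingRoundingDoublingHighMul returns round(a * b / 2^31), with the single
// overflow case (a == b == INT32_MIN) saturated to INT32_MAX. Illustrative example with values
// not taken from the original source: a = 100000, b = 50000 gives ab_64 = 5000000000, and
// (5000000000 + 2^30) / 2^31 truncates to 2, matching round(5000000000 / 2^31) = 2.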

int32_t QuantizedMultiplierSmallerThanOne::RoundingDivideByPOT(int32_t x, int exponent)
{
    ARMNN_ASSERT(exponent >= 0 && exponent <= 31);
    int32_t mask = (1 << exponent) - 1;
    int32_t remainder = x & mask;
    int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
    return (x >> exponent) + (remainder > threshold ? 1 : 0);
}
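
// A worked example of the fixed-point requantization implemented above, using illustrative
// values that are not taken from the original source. For multiplier = 0.75f, std::frexp
// returns q = 0.75 with exponent 0, so m_RightShift = 0 and
// m_Multiplier = round(0.75 * 2^31) = 1610612736. Applying operator* to rhs = 1001 then gives:
//   SaturatingRoundingDoublingHighMul(1001, 1610612736) = (1001 * 1610612736 + 2^30) / 2^31 = 751
//   RoundingDivideByPOT(751, 0) = 751
// which matches round(1001 * 0.75) = 751, i.e. the class approximates multiplication by the
// original float multiplier using only integer arithmetic.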

void Convolve(const TensorShape& rInputShape,
              Decoder<float>& rInputDecoder,
              const TensorShape& rOutputShape,
              Encoder<float>& rOutputEncoder,
              const TensorShape& rFilterShape,
              Decoder<float>& rFilterDecoder,
              bool biasEnabled,
              Decoder<float>* pBiasDecoder,
              DataLayout dataLayout,
              unsigned int paddingTop,
              unsigned int paddingLeft,
              unsigned int xStride,
              unsigned int yStride,
              unsigned int xDilation,
              unsigned int yDilation,
              bool depthwise)
{
    if (biasEnabled && !pBiasDecoder)
    {
        throw InvalidArgumentException("Bias is enabled but the bias data is invalid");
    }
    const armnnUtils::DataLayoutIndexed dataLayoutIndexed(dataLayout);

    const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex();
    const unsigned int heightIndex   = dataLayoutIndexed.GetHeightIndex();
    const unsigned int widthIndex    = dataLayoutIndexed.GetWidthIndex();

    const unsigned int depthMultiplier = depthwise ? rFilterShape[0] : 1;
    const unsigned int inputChannels   = depthwise ? rFilterShape[1] : rFilterShape[channelsIndex];
    const unsigned int outputChannels  = depthwise ? inputChannels * depthMultiplier : rFilterShape[0];

    const unsigned int batchSize    = rOutputShape[0];
    const unsigned int outputHeight = rOutputShape[heightIndex];
    const unsigned int outputWidth  = rOutputShape[widthIndex];
    const unsigned int inputHeight  = rInputShape[heightIndex];
    const unsigned int inputWidth   = rInputShape[widthIndex];

    const unsigned int filterHeight = depthwise ? rFilterShape[2] : rFilterShape[heightIndex];
    const unsigned int filterWidth  = depthwise ? rFilterShape[3] : rFilterShape[widthIndex];

    const std::vector<float> inputVec = rInputDecoder.DecodeTensor(rInputShape);
    const std::vector<float> filterVec = rFilterDecoder.DecodeTensor(rFilterShape, depthMultiplier, depthwise);

    const TensorShape biasShape{outputChannels};
    const std::vector<float> biasVec = biasEnabled ? pBiasDecoder->DecodeTensor(biasShape) : std::vector<float>();

    unsigned int depthwiseMultiplierIdx = 0;
    for (unsigned int batchIdx = 0; batchIdx < batchSize; batchIdx++)
    {
        for (unsigned int cOutput = 0; cOutput < outputChannels; cOutput++)
        {
            for (unsigned int yOutput = 0; yOutput < outputHeight; yOutput++)
            {
                for (unsigned int xOutput = 0; xOutput < outputWidth; xOutput++)
                {
                    // This loop goes over each output element.
                    float sum = 0.0f;

                    // For a depthwise convolution, each output channel corresponds to exactly one input channel.
                    // For a regular convolution, loop over every input channel.
                    for (unsigned int cInput = 0; cInput < (depthwise ? 1 : inputChannels); cInput++)
                    {
                        if (depthwise)
                        {
                            depthwiseMultiplierIdx = 0;
                            cInput = cOutput / depthMultiplier;
                            depthwiseMultiplierIdx = cOutput % depthMultiplier;
                        }

                        for (unsigned int yFilter = 0; yFilter < filterHeight; yFilter++)
                        {
                            for (unsigned int xFilter = 0; xFilter < filterWidth; xFilter++)
                            {
                                // This loop goes over each input element for each output element.
                                unsigned int filterIndex = 0;

                                // The kernel dimensionality depends on whether the convolution is depthwise,
                                // so the index calculation does too.
                                if (depthwise)
                                {
                                    filterIndex = depthwiseMultiplierIdx * filterWidth * filterHeight * inputChannels +
                                                  cInput * filterWidth * filterHeight +
                                                  yFilter * filterWidth +
                                                  xFilter;
                                }
                                else
                                {
                                    // Keep this implementation, as using DataLayoutIndexed::GetIndex causes a
                                    // significant performance regression.
                                    if (dataLayoutIndexed.GetDataLayout() == DataLayout::NHWC)
                                    {
                                        filterIndex = cOutput * filterHeight * filterWidth * inputChannels +
                                                      yFilter * filterWidth * inputChannels +
                                                      xFilter * inputChannels +
                                                      cInput;
                                    }
                                    else
                                    {
                                        filterIndex = cOutput * filterWidth * filterHeight * inputChannels +
                                                      cInput * filterWidth * filterHeight +
                                                      yFilter * filterWidth +
                                                      xFilter;
                                    }
                                }

                                unsigned int yInput = yOutput * yStride + yFilter * yDilation;
                                unsigned int xInput = xOutput * xStride + xFilter * xDilation;

                                float inputValue;

                                // Check if we're in the padding.
                                if (yInput < paddingTop || yInput >= inputHeight + paddingTop ||
                                    xInput < paddingLeft || xInput >= inputWidth + paddingLeft)
                                {
                                    inputValue = 0.0f;
                                }
                                else
                                {
                                    unsigned int inputIndex = 0;

                                    // Keep this implementation, as using DataLayoutIndexed::GetIndex causes a
                                    // significant performance regression.
                                    if (dataLayoutIndexed.GetDataLayout() == DataLayout::NHWC)
                                    {
                                        inputIndex = batchIdx * inputHeight * inputWidth * inputChannels +
                                                     (yInput - paddingTop) * inputWidth * inputChannels +
                                                     (xInput - paddingLeft) * inputChannels +
                                                     cInput;
                                    }
                                    else
                                    {
                                        inputIndex = batchIdx * inputWidth * inputHeight * inputChannels +
                                                     inputWidth * inputHeight * cInput +
                                                     inputWidth * (yInput - paddingTop) +
                                                     xInput - paddingLeft;
                                    }
                                    inputValue = inputVec[inputIndex];
                                }

                                sum += filterVec[filterIndex] * inputValue;
                            }
                        }
                    }

                    if (biasEnabled)
                    {
                        sum += biasVec[cOutput];
                    }

                    unsigned int outIdx;
                    if (dataLayoutIndexed.GetDataLayout() == DataLayout::NHWC)
                    {
                        outIdx = batchIdx * outputHeight * outputWidth * outputChannels +
                                 yOutput * outputWidth * outputChannels +
                                 xOutput * outputChannels +
                                 cOutput;
                    }
                    else
                    {
                        outIdx = batchIdx * outputHeight * outputWidth * outputChannels +
                                 cOutput * outputHeight * outputWidth +
                                 yOutput * outputWidth +
                                 xOutput;
                    }

                    // Position the encoder on the output element, then write the accumulated value.
                    rOutputEncoder[outIdx];
                    rOutputEncoder.Set(sum);
                }
            }
        }
    }
}
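
// Illustrative example of the NHWC output indexing used above, with values chosen purely for
// illustration (not taken from the original source): with outputHeight = 2, outputWidth = 3 and
// outputChannels = 4, the element at batchIdx = 0, yOutput = 1, xOutput = 2, cOutput = 3 maps to
//   outIdx = 0*2*3*4 + 1*3*4 + 2*4 + 3 = 23
// i.e. the last of the 1*2*3*4 = 24 output elements, as expected when channels are the
// fastest-varying dimension in the NHWC layout.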

} // namespace armnn