• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_UTILS_H
18 #define ANDROID_ML_NN_COMMON_OPERATIONS_UTILS_H
19 
20 #include "Utils.h"
21 
22 #include <cstdint>
23 #include <vector>
24 
// Checks that an operation's input parameters are valid. When `v` is false it
// logs the stringified condition and makes the *enclosing function* return
// false, so it may only be used inside functions returning bool.
//
// NOTE: the expansion deliberately keeps its historical trailing `;` after
// `while (0)` so that any existing call site written without a semicolon still
// compiles. A consequence is that `if (c) NN_CHECK(x); else ...` does not
// compile; wrap such uses in braces.
#define NN_CHECK(v)                                                     \
  do {                                                                  \
    if (!(v)) {                                                         \
      LOG(ERROR) << "NN_CHECK failed: '" << #v << "'\n";                \
      return false;                                                     \
    }                                                                   \
  } while (0);

// Checks that `actual == expected`; same failure behavior as NN_CHECK.
#define NN_CHECK_EQ(actual, expected)           \
  NN_CHECK((actual) == (expected))

// Alias of NN_CHECK for call sites that use the NN_OPS_CHECK spelling.
#define NN_OPS_CHECK NN_CHECK
38 
39 namespace android {
40 namespace nn {
41 
// Implicit padding schemes. The numeric values are fixed because they are
// compared against raw int32_t scheme codes (see calculateExplicitPadding).
enum PaddingScheme {
    kPaddingUnknown = 0,  // explicit padding that matches no known scheme
    kPaddingSame = 1,     // pad so that ceil(input / stride) windows fit
    kPaddingValid = 2     // no padding at all
};
47 
48 // The type and dimensions of an operand.
49 struct Shape {
50     OperandType type;
51     std::vector<uint32_t> dimensions;
52     float scale;
53     int32_t offset;
54 };
55 
56 // Verifies that the two shapes are the same.
57 bool SameShape(const Shape& in1, const Shape& in2);
58 
59 // Sets out to the same shape as in.
60 bool SetShape(const Shape& in, Shape* out);
61 
62 // Return the total number of elements, i.e. all the dimensions multiplied
63 // together. For a scalar, returns one.
64 uint32_t getNumberOfElements(const Shape& shape);
65 
66 uint32_t getNumberOfDimensions(const Shape& shape);
67 
68 uint32_t getSizeOfDimension(const Shape& shape, uint32_t dimensionIdx);
69 
// Computes the number of output positions along one axis of a windowed
// operation with explicit padding:
//   out = (imageSize + paddingHead + paddingTail - filterSize) / stride + 1
// Returns 0 when `stride` is 0, which would otherwise divide by zero. It also
// returns 0 when the filter does not fit inside the padded input; plain
// unsigned arithmetic would wrap around there and yield a huge, bogus size.
inline uint32_t computeOutSize(uint32_t imageSize, uint32_t filterSize, uint32_t stride,
                               uint32_t paddingHead, uint32_t paddingTail) {
    if (stride == 0) {
        return 0;
    }
    // Widen before adding so the padded extent itself cannot overflow 32 bits.
    const uint64_t paddedSize =
            static_cast<uint64_t>(imageSize) + paddingHead + paddingTail;
    if (paddedSize < filterSize) {
        return 0;  // the window never fits: no valid output positions
    }
    return static_cast<uint32_t>((paddedSize - filterSize + stride) / stride);
}
74 
75 __wur
76 bool QuantizeMultiplierSmallerThanOne(double double_multiplier,
77                                       int32_t* quantized_multiplier,
78                                       int32_t* right_shift);
79 
80 __wur
81 bool QuantizeMultiplierGreaterThanOne(double double_multiplier,
82                                       int32_t* quantized_multiplier,
83                                       int* left_shift);
84 
85 __wur
86 bool GetQuantizedConvolutionMultipler(const Shape& inputShape,
87                                       const Shape& filterShape,
88                                       const Shape& biasShape,
89                                       const Shape& outputShape,
90                                       float* multiplier);
91 
92 void CalculateActivationRangeUint8(int32_t activation,
93                                    const Shape& outputShape,
94                                    int32_t* act_min,
95                                    int32_t* act_max);
96 
97 void CalculateActivationRangeFloat(int32_t activation,
98                                    float* activation_min,
99                                    float* activation_max);
100 
101 int32_t CalculateInputRadius(int input_integer_bits, int input_left_shift);
102 
calculateExplicitPadding(int32_t in_size,int32_t stride,int32_t filter_size,int32_t padding_implicit,int32_t * padding_head,int32_t * padding_tail)103 inline void calculateExplicitPadding(int32_t in_size, int32_t stride,
104                                      int32_t filter_size, int32_t padding_implicit,
105                                      int32_t* padding_head, int32_t* padding_tail) {
106     *padding_head = 0;
107     *padding_tail = 0;
108 
109     if (padding_implicit == kPaddingSame) {
110         int32_t out_size = (in_size + stride - 1) / stride;
111         int32_t tmp = (out_size - 1) * stride + filter_size;
112         if (tmp > in_size) {
113             *padding_head = (tmp - in_size) / 2;
114             *padding_tail = (tmp - in_size) - *padding_head;
115         }
116     }
117 }
118 
getPaddingScheme(int32_t inWidth,int32_t inHeight,int32_t strideWidth,int32_t strideHeight,int32_t filterWidth,int32_t filterHeight,int32_t paddingLeft,int32_t paddingRight,int32_t paddingTop,int32_t paddingBottom)119 inline PaddingScheme getPaddingScheme(int32_t inWidth, int32_t inHeight,
120                                       int32_t strideWidth, int32_t strideHeight,
121                                       int32_t filterWidth, int32_t filterHeight,
122                                       int32_t paddingLeft, int32_t paddingRight,
123                                       int32_t paddingTop, int32_t paddingBottom) {
124     if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 && paddingBottom == 0) {
125         return kPaddingValid;
126     }
127 
128     int32_t expectedPaddingLeft, expectedPaddingRight;
129     int32_t expectedPaddingTop, expectedPaddingBottom;
130 
131     calculateExplicitPadding(inWidth, strideWidth, filterWidth, kPaddingSame,
132                              &expectedPaddingLeft, &expectedPaddingRight);
133     calculateExplicitPadding(inHeight, strideHeight, filterHeight, kPaddingSame,
134                              &expectedPaddingTop, &expectedPaddingBottom);
135     if (expectedPaddingLeft == paddingLeft && expectedPaddingRight == paddingRight &&
136         expectedPaddingTop == paddingTop && expectedPaddingBottom == paddingBottom) {
137         return kPaddingSame;
138     } else {
139         return kPaddingUnknown;
140     }
141 }
142 
// Reverses the order of the lowest `num_dimensions` bits of `mask`, so that the
// bit for dimension 0 becomes the most significant of those bits. This matches
// the bit order the kernel expects. Bits at or above `num_dimensions` are
// ignored, and a non-positive count yields 0.
// TODO: add more documentation from upstream.
inline int ReverseMaskBits(int mask, int num_dimensions) {
    int reversed = 0;
    for (int bit = 0; bit < num_dimensions; ++bit) {
        if ((mask >> bit) & 1) {
            reversed |= 1 << (num_dimensions - 1 - bit);
        }
    }
    return reversed;
}
154 
// Returns `dividend` modulo `divisor` with the sign of the divisor (floored
// modulo). For a positive divisor the result therefore lies in [0, divisor).
// `divisor` must be non-zero.
// The truncated remainder is corrected directly, rather than computing
// (divisor + dividend % divisor) % divisor. That intermediate sum overflows
// int32_t (undefined behavior) for divisors close to INT32_MAX.
// TODO: add more documentation from upstream.
inline int32_t PositiveRemainder(int32_t dividend, int32_t divisor) {
  const int32_t remainder = dividend % divisor;
  // C++ '%' truncates toward zero, so the remainder carries the dividend's
  // sign; move it into the divisor's range when the signs disagree.
  if (remainder != 0 && ((remainder < 0) != (divisor < 0))) {
    return remainder + divisor;
  }
  return remainder;
}
159 
160 // TODO: add more documentation from upstream.
ClampedIndex(int32_t index,int dim,bool pos_stride)161 inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) {
162   return pos_stride
163              ? (index >= dim ? dim
164                              : PositiveRemainder(
165                                    std::min(std::max(index, -dim), dim), dim))
166              : (index < -dim
167                     ? -1
168                     : PositiveRemainder(
169                           std::min(std::max(index, -dim), dim - 1), dim));
170 }
171 
172 // Preparation functions for the corresponding ops
173 bool addMulPrepare(const Shape& in1, const Shape& in2, Shape* out1);
174 
175 bool floorPrepare(const Shape& input, Shape* output);
176 
177 bool dequantizePrepare(const Shape& input, Shape* output);
178 
179 bool depthwiseConvPrepare(const Shape& input,
180                           const Shape& filter,
181                           const Shape& bias,
182                           int32_t padding_left, int32_t padding_right,
183                           int32_t padding_top, int32_t padding_bottom,
184                           int32_t stride_width, int32_t stride_height,
185                           Shape* output);
186 
187 bool convPrepare(const Shape& input,
188                  const Shape& filter,
189                  const Shape& bias,
190                  int32_t padding_left, int32_t padding_right,
191                  int32_t padding_top, int32_t padding_bottom,
192                  int32_t stride_width, int32_t stride_height,
193                  Shape* output);
194 
195 bool genericPoolingPrepare(const Shape& input,
196                            int32_t padding_left, int32_t padding_right,
197                            int32_t padding_top, int32_t padding_bottom,
198                            int32_t stride_width, int32_t stride_height,
199                            int32_t filter_width, int32_t filter_height,
200                            Shape* output);
201 
202 bool genericActivationPrepare(const Shape& input, Shape* output);
203 
204 bool fullyConnectedPrepare(const Shape& input,
205                            const Shape& weights,
206                            const Shape& bias,
207                            Shape* output);
208 
209 bool concatenationPrepare(const std::vector<Shape>& inputShapes,
210                           int32_t axis,
211                           Shape* output);
212 
213 bool genericNormalizationPrepare(const Shape& input, Shape* output);
214 
215 bool reshapePrepare(const Shape& input,
216                     const int32_t* targetDims,
217                     const int32_t targetDimsSize,
218                     Shape* output);
219 
220 bool resizeBilinearPrepare(const Shape& input,
221                            int32_t height,
222                            int32_t width,
223                            Shape* output);
224 
225 bool depthToSpacePrepare(const Shape& input,
226                          int32_t blockSize,
227                          Shape* output);
228 
229 bool spaceToDepthPrepare(const Shape& input,
230                          int32_t blockSize,
231                          Shape* output);
232 
233 bool embeddingLookupPrepare(const Shape &valueShape,
234                             const Shape &lookupShape,
235                             Shape *outputShape);
236 
237 bool hashtableLookupPrepare(const Shape &lookupShape,
238                             const Shape &keyShape,
239                             const Shape &valueShape,
240                             Shape *outputShape,
241                             Shape *hitShape);
242 
243 bool padPrepare(const Shape& input,
244                 const int32_t* paddingsData,
245                 const Shape& paddingsShape,
246                 Shape* output);
247 
248 bool batchToSpacePrepare(const Shape& input,
249                          const int32_t* blockSizeData,
250                          const Shape& blockSizeShape,
251                          Shape* output);
252 
253 bool spaceToBatchPrepare(const Shape& input,
254                          const int32_t* blockSizeData,
255                          const Shape& blockSizeShape,
256                          const int32_t* paddingsData,
257                          const Shape& paddingsShape,
258                          Shape* output);
259 
260 bool squeezePrepare(const Shape& input,
261                     const int32_t* squeezeDims,
262                     const Shape& squeezeDimsShape,
263                     Shape* output);
264 
265 bool transposePrepare(const Shape& input,
266                       const int32_t* permData,
267                       const Shape& permShape,
268                       Shape* output);
269 
270 bool meanPrepare(const Shape& input,
271                  const int32_t* axisData,
272                  const Shape& axisShape,
273                  bool keepDims,
274                  Shape* output);
275 
276 bool stridedSlicePrepare(const Shape& input,
277                          const int32_t* beginData, const Shape& beginShape,
278                          const int32_t* endData, const Shape& endShape,
279                          const int32_t* stridesData, const Shape& stridesShape,
280                          int32_t beginMask, int32_t endMask, int32_t shrinkAxisMask,
281                          Shape* output);
282 } // namespace nn
283 } // namespace android
284 
285 #endif // ANDROID_ML_NN_COMMON_OPERATIONS_UTILS_H
286