• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_H
18 #define ANDROID_ML_NN_COMMON_OPERATIONS_H
19 
20 #include "operations/EmbeddingLookup.h"
21 #include "operations/HashtableLookup.h"
22 #include "operations/LSHProjection.h"
23 #include "operations/LSTM.h"
24 #include "operations/RNN.h"
25 #include "operations/SVDF.h"
26 
27 #include <stddef.h>
28 
29 #include <cstdint>
30 #include <vector>
31 
namespace android {
namespace nn {

// Free-function entry points for individual NN operations. This header only
// declares them; the definitions live in other translation units.
//
// Naming convention, as visible in the signatures below:
//   *Float32 - tensor data buffers are `float`.
//   *Quant8  - tensor data buffers are `uint8_t`; where a bias buffer is taken
//              (conv, depthwise conv, fully connected) it is `int32_t`.
//   *Generic - tensor data buffers are untyped (`uint8_t*` byte pointers or
//              `void*`), so the element type presumably comes from the
//              accompanying Shape. TODO: confirm against the implementations.
//
// Every function returns `bool`. This is presumably `true` on success and
// `false` on failure; confirm against the implementations.
// Each data pointer is paired with a `Shape` describing it. `activation` is an
// int32_t code, presumably selecting a fused activation applied to the output.

// Only a forward declaration is needed here: every use below is by reference.
struct Shape;

// ---- Element-wise binary arithmetic: in1 (op) in2 -> out --------------------

bool addFloat32(const float* in1, const Shape& shape1,
                const float* in2, const Shape& shape2,
                int32_t activation,
                float* out, const Shape& shapeOut);
bool addQuant8(const uint8_t* in1, const Shape& shape1,
               const uint8_t* in2, const Shape& shape2,
               int32_t activation,
               uint8_t* out, const Shape& shapeOut);

bool mulFloat32(const float* in1, const Shape& shape1,
                const float* in2, const Shape& shape2,
                int32_t activation,
                float* out, const Shape& shapeOut);
bool mulQuant8(const uint8_t* in1, const Shape& shape1,
               const uint8_t* in2, const Shape& shape2,
               int32_t activation,
               uint8_t* out, const Shape& shapeOut);

// ---- Unary element-wise ops sharing a single Shape for input and output -----

bool floorFloat32(const float* inputData,
                  float* outputData,
                  const Shape& shape);

// Reads uint8_t input and writes float output; both are described by `shape`.
bool dequantizeQuant8ToFloat32(const uint8_t* inputData,
                               float* outputData,
                               const Shape& shape);

// ---- Convolutions -----------------------------------------------------------
// Padding is given explicitly per edge (left/right/top/bottom) and strides per
// axis (width/height).

bool depthwiseConvFloat32(const float* inputData, const Shape& inputShape,
                          const float* filterData, const Shape& filterShape,
                          const float* biasData, const Shape& biasShape,
                          int32_t padding_left, int32_t padding_right,
                          int32_t padding_top, int32_t padding_bottom,
                          int32_t stride_width, int32_t stride_height,
                          int32_t depth_multiplier, int32_t activation,
                          float* outputData, const Shape& outputShape);
bool depthwiseConvQuant8(const uint8_t* inputData, const Shape& inputShape,
                         const uint8_t* filterData, const Shape& filterShape,
                         const int32_t* biasData, const Shape& biasShape,
                         int32_t padding_left, int32_t padding_right,
                         int32_t padding_top, int32_t padding_bottom,
                         int32_t stride_width, int32_t stride_height,
                         int32_t depth_multiplier, int32_t activation,
                         uint8_t* outputData, const Shape& outputShape);

bool convFloat32(const float* inputData, const Shape& inputShape,
                 const float* filterData, const Shape& filterShape,
                 const float* biasData, const Shape& biasShape,
                 int32_t padding_left, int32_t padding_right,
                 int32_t padding_top, int32_t padding_bottom,
                 int32_t stride_width, int32_t stride_height,
                 int32_t activation,
                 float* outputData, const Shape& outputShape);
bool convQuant8(const uint8_t* inputData, const Shape& inputShape,
                const uint8_t* filterData, const Shape& filterShape,
                const int32_t* biasData, const Shape& biasShape,
                int32_t padding_left, int32_t padding_right,
                int32_t padding_top, int32_t padding_bottom,
                int32_t stride_width, int32_t stride_height,
                int32_t activation,
                uint8_t* outputData, const Shape& outputShape);

// ---- Pooling ----------------------------------------------------------------
// Same explicit padding/stride convention as the convolutions, plus the
// pooling window size (filter_width x filter_height).

bool averagePoolFloat32(const float* inputData, const Shape& inputShape,
                        int32_t padding_left, int32_t padding_right,
                        int32_t padding_top, int32_t padding_bottom,
                        int32_t stride_width, int32_t stride_height,
                        int32_t filter_width, int32_t filter_height, int32_t activation,
                        float* outputData, const Shape& outputShape);
bool averagePoolQuant8(const uint8_t* inputData, const Shape& inputShape,
                       int32_t padding_left, int32_t padding_right,
                       int32_t padding_top, int32_t padding_bottom,
                       int32_t stride_width, int32_t stride_height,
                       int32_t filter_width, int32_t filter_height, int32_t activation,
                       uint8_t* outputData, const Shape& outputShape);
bool l2PoolFloat32(const float* inputData, const Shape& inputShape,
                   int32_t padding_left, int32_t padding_right,
                   int32_t padding_top, int32_t padding_bottom,
                   int32_t stride_width, int32_t stride_height,
                   int32_t filter_width, int32_t filter_height, int32_t activation,
                   float* outputData, const Shape& outputShape);
bool maxPoolFloat32(const float* inputData, const Shape& inputShape,
                    int32_t padding_left, int32_t padding_right,
                    int32_t padding_top, int32_t padding_bottom,
                    int32_t stride_width, int32_t stride_height,
                    int32_t filter_width, int32_t filter_height, int32_t activation,
                    float* outputData, const Shape& outputShape);
bool maxPoolQuant8(const uint8_t* inputData, const Shape& inputShape,
                   int32_t padding_left, int32_t padding_right,
                   int32_t padding_top, int32_t padding_bottom,
                   int32_t stride_width, int32_t stride_height,
                   int32_t filter_width, int32_t filter_height, int32_t activation,
                   uint8_t* outputData, const Shape& outputShape);

// ---- Activations --------------------------------------------------------------

bool reluFloat32(const float* inputData, const Shape& inputShape,
                 float* outputData, const Shape& outputShape);
bool relu1Float32(const float* inputData, const Shape& inputShape,
                  float* outputData, const Shape& outputShape);
bool relu6Float32(const float* inputData, const Shape& inputShape,
                  float* outputData, const Shape& outputShape);
bool tanhFloat32(const float* inputData, const Shape& inputShape,
                 float* outputData, const Shape& outputShape);
bool logisticFloat32(const float* inputData, const Shape& inputShape,
                     float* outputData, const Shape& outputShape);
// `beta` is passed by value. A top-level `const` on it would not be part of
// the function type, so it is omitted from the declaration.
bool softmaxFloat32(const float* inputData, const Shape& inputShape,
                    float beta,
                    float* outputData, const Shape& outputShape);
bool reluQuant8(const uint8_t* inputData, const Shape& inputShape,
                uint8_t* outputData, const Shape& outputShape);
bool relu1Quant8(const uint8_t* inputData, const Shape& inputShape,
                 uint8_t* outputData, const Shape& outputShape);
bool relu6Quant8(const uint8_t* inputData, const Shape& inputShape,
                 uint8_t* outputData, const Shape& outputShape);
bool logisticQuant8(const uint8_t* inputData, const Shape& inputShape,
                    uint8_t* outputData, const Shape& outputShape);
// `beta` stays a float even for the quantized variant.
bool softmaxQuant8(const uint8_t* inputData, const Shape& inputShape,
                   float beta,
                   uint8_t* outputData, const Shape& outputShape);

// ---- Fully connected ----------------------------------------------------------

bool fullyConnectedFloat32(const float* inputData, const Shape& inputShape,
                           const float* weights, const Shape& weightsShape,
                           const float* biasData, const Shape& biasShape,
                           int32_t activation,
                           float* outputData, const Shape& outputShape);
bool fullyConnectedQuant8(const uint8_t* inputData, const Shape& inputShape,
                          const uint8_t* weights, const Shape& weightsShape,
                          const int32_t* biasData, const Shape& biasShape,
                          int32_t activation,
                          uint8_t* outputData, const Shape& outputShape);

// ---- Concatenation --------------------------------------------------------------
// `inputDataPtrs` and `inputShapes` are presumably parallel arrays (entry i of
// each describes input i). TODO: confirm against the implementation.

bool concatenationFloat32(const std::vector<const float*>& inputDataPtrs,
                          const std::vector<Shape>& inputShapes, int32_t axis,
                          float* outputData, const Shape& outputShape);
bool concatenationQuant8(const std::vector<const uint8_t*>& inputDataPtrs,
                         const std::vector<Shape>& inputShapes, int32_t axis,
                         uint8_t* outputData, const Shape& outputShape);

// ---- Normalization ------------------------------------------------------------

bool l2normFloat32(const float* inputData, const Shape& inputShape,
                   float* outputData, const Shape& outputShape);
bool l2normQuant8(const uint8_t* inputData, const Shape& inputShape,
                  uint8_t* outputData, const Shape& outputShape);
bool localResponseNormFloat32(const float* inputData, const Shape& inputShape,
                              int32_t radius, float bias, float alpha, float beta,
                              float* outputData, const Shape& outputShape);

// ---- Shape / layout manipulation ----------------------------------------------

// Untyped (void*) buffers: the element type is not part of the signature.
bool reshapeGeneric(const void* inputData, const Shape& inputShape,
                    void* outputData, const Shape& outputShape);

bool resizeBilinearFloat32(const float* inputData,
                           const Shape& inputShape,
                           float* outputData,
                           const Shape& outputShape);

bool depthToSpaceGeneric(const uint8_t* inputData, const Shape& inputShape,
                         int32_t blockSize,
                         uint8_t* outputData, const Shape& outputShape);

bool spaceToDepthGeneric(const uint8_t* inputData, const Shape& inputShape,
                         int32_t blockSize,
                         uint8_t* outputData, const Shape& outputShape);

// `paddings` has no accompanying Shape, so its length is presumably implied by
// inputShape. TODO: confirm.
bool padGeneric(const uint8_t* inputData, const Shape& inputShape,
                const int32_t* paddings,
                uint8_t* outputData, const Shape& outputShape);

bool batchToSpaceGeneric(const uint8_t* inputData, const Shape& inputShape,
                         const int32_t* blockSize,
                         uint8_t* outputData, const Shape& outputShape);

bool spaceToBatchGeneric(const uint8_t* inputData, const Shape& inputShape,
                         const int32_t* blockSize,
                         const int32_t* padding, const Shape& paddingShape,
                         uint8_t* outputData, const Shape& outputShape);

// ---- More element-wise binary arithmetic (float only) ---------------------------

bool subFloat32(const float* in1, const Shape& shape1,
                const float* in2, const Shape& shape2,
                int32_t activation,
                float* out, const Shape& shapeOut);

// Untyped (void*) buffers, like reshapeGeneric.
bool squeezeGeneric(const void* inputData, const Shape& inputShape,
                    void* outputData, const Shape& outputShape);

bool divFloat32(const float* in1, const Shape& shape1,
                const float* in2, const Shape& shape2,
                int32_t activation,
                float* out, const Shape& shapeOut);

bool transposeGeneric(const uint8_t* inputData, const Shape& inputShape,
                      const int32_t* perm, const Shape& permShape,
                      uint8_t* outputData, const Shape& outputShape);

// ---- Reductions and slicing -----------------------------------------------------

bool meanGeneric(const uint8_t* inputData, const Shape& inputShape,
                 const int32_t* axis, const Shape& axisShape, bool keepDims,
                 uint8_t* outputData, const Shape& outputShape);

// beginMask / endMask / shrinkAxisMask are int32_t bit masks. The per-bit
// meaning is defined by the implementation; verify there before relying on it.
bool stridedSliceGeneric(const uint8_t* inputData, const Shape& inputShape,
                         const int32_t* beginData, const int32_t* endData,
                         const int32_t* stridesData,
                         int32_t beginMask, int32_t endMask, int32_t shrinkAxisMask,
                         uint8_t* outputData, const Shape& outputShape);

}  // namespace nn
}  // namespace android
236 #endif // ANDROID_ML_NN_COMMON_OPERATIONS_H
237