/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_

#include <stdint.h>
#include <sys/types.h>

#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/legacy_types.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/reference/tanh.h"
#include "tensorflow/lite/kernels/internal/types.h"

31 namespace tflite {
32 
33 namespace reference_ops {
34 
35 static constexpr int kDepthwiseReverseShift = -1;
36 
ShapeFromDims(const tflite::Dims<4> & dims,RuntimeShape * shape)37 inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
38   shape->BuildFrom(
39       {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
40 }
41 
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int depth_multiplier,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)42 inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
43                           const float* filter_data, const Dims<4>& filter_dims,
44                           const float* bias_data, const Dims<4>& bias_dims,
45                           int stride_width, int stride_height,
46                           int dilation_width_factor, int dilation_height_factor,
47                           int pad_width, int pad_height, int depth_multiplier,
48                           float output_activation_min,
49                           float output_activation_max, float* output_data,
50                           const Dims<4>& output_dims) {
51   tflite::DepthwiseParams op_params;
52   // Padding type is ignored, but still set.
53   op_params.padding_type = PaddingType::kSame;
54   op_params.padding_values.width = pad_width;
55   op_params.padding_values.height = pad_height;
56   op_params.stride_width = stride_width;
57   op_params.stride_height = stride_height;
58   op_params.dilation_width_factor = dilation_width_factor;
59   op_params.dilation_height_factor = dilation_height_factor;
60   op_params.depth_multiplier = depth_multiplier;
61   op_params.float_activation_min = output_activation_min;
62   op_params.float_activation_max = output_activation_max;
63 
64   DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
65                 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
66                 bias_data, DimsToShape(output_dims), output_data);
67 }
68 
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)69 inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
70                           const float* filter_data, const Dims<4>& filter_dims,
71                           const float* bias_data, const Dims<4>& bias_dims,
72                           int stride_width, int stride_height, int pad_width,
73                           int pad_height, int depth_multiplier,
74                           float output_activation_min,
75                           float output_activation_max, float* output_data,
76                           const Dims<4>& output_dims) {
77   DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
78                 bias_dims, stride_width, stride_height, 1, 1, pad_width,
79                 pad_height, depth_multiplier, output_activation_min,
80                 output_activation_max, output_data, output_dims);
81 }
82 
83 // Legacy, for compatibility with old checked-in code.
84 template <FusedActivationFunctionType Ac>
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,float * output_data,const Dims<4> & output_dims)85 void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
86                    const float* filter_data, const Dims<4>& filter_dims,
87                    const float* bias_data, const Dims<4>& bias_dims,
88                    int stride_width, int stride_height, int pad_width,
89                    int pad_height, int depth_multiplier, float* output_data,
90                    const Dims<4>& output_dims) {
91   float output_activation_min, output_activation_max;
92   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
93   DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
94                 bias_dims, stride_width, stride_height, pad_width, pad_height,
95                 depth_multiplier, output_activation_min, output_activation_max,
96                 output_data, output_dims);
97 }
98 
99 // Legacy, for compatibility with old checked-in code.
100 template <FusedActivationFunctionType Ac>
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride,int pad_width,int pad_height,int depth_multiplier,float * output_data,const Dims<4> & output_dims)101 void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
102                    const float* filter_data, const Dims<4>& filter_dims,
103                    const float* bias_data, const Dims<4>& bias_dims, int stride,
104                    int pad_width, int pad_height, int depth_multiplier,
105                    float* output_data, const Dims<4>& output_dims) {
106   DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
107                     bias_dims, stride, stride, pad_width, pad_height,
108                     depth_multiplier, output_data, output_dims);
109 }
110 
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)111 inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
112                           int32 input_offset, const uint8* filter_data,
113                           const Dims<4>& filter_dims, int32 filter_offset,
114                           const int32* bias_data, const Dims<4>& bias_dims,
115                           int stride_width, int stride_height,
116                           int dilation_width_factor, int dilation_height_factor,
117                           int pad_width, int pad_height, int depth_multiplier,
118                           int32 output_offset, int32 output_multiplier,
119                           int output_shift, int32 output_activation_min,
120                           int32 output_activation_max, uint8* output_data,
121                           const Dims<4>& output_dims) {
122   tflite::DepthwiseParams op_params;
123   // Padding type is ignored, but still set.
124   op_params.padding_type = PaddingType::kSame;
125   op_params.padding_values.width = pad_width;
126   op_params.padding_values.height = pad_height;
127   op_params.stride_width = stride_width;
128   op_params.stride_height = stride_height;
129   op_params.dilation_width_factor = dilation_width_factor;
130   op_params.dilation_height_factor = dilation_height_factor;
131   op_params.depth_multiplier = depth_multiplier;
132   op_params.quantized_activation_min = output_activation_min;
133   op_params.quantized_activation_max = output_activation_max;
134   op_params.input_offset = input_offset;
135   op_params.weights_offset = filter_offset;
136   op_params.output_offset = output_offset;
137   op_params.output_multiplier = output_multiplier;
138   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
139   op_params.output_shift = kDepthwiseReverseShift * output_shift;
140 
141   DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
142                 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
143                 bias_data, DimsToShape(output_dims), output_data);
144 }
145 
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)146 inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
147                           int32 input_offset, const uint8* filter_data,
148                           const Dims<4>& filter_dims, int32 filter_offset,
149                           const int32* bias_data, const Dims<4>& bias_dims,
150                           int stride_width, int stride_height, int pad_width,
151                           int pad_height, int depth_multiplier,
152                           int32 output_offset, int32 output_multiplier,
153                           int output_shift, int32 output_activation_min,
154                           int32 output_activation_max, uint8* output_data,
155                           const Dims<4>& output_dims) {
156   DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
157                 filter_offset, bias_data, bias_dims, stride_width,
158                 stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
159                 output_offset, output_multiplier, output_shift,
160                 output_activation_min, output_activation_max, output_data,
161                 output_dims);
162 }
163 
164 // Legacy, for compatibility with old checked-in code.
165 template <FusedActivationFunctionType Ac>
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)166 void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
167                    int32 input_offset, const uint8* filter_data,
168                    const Dims<4>& filter_dims, int32 filter_offset,
169                    const int32* bias_data, const Dims<4>& bias_dims,
170                    int stride_width, int stride_height, int pad_width,
171                    int pad_height, int depth_multiplier, int32 output_offset,
172                    int32 output_multiplier, int output_shift,
173                    int32 output_activation_min, int32 output_activation_max,
174                    uint8* output_data, const Dims<4>& output_dims) {
175   if (Ac == FusedActivationFunctionType::kNone) {
176     TFLITE_DCHECK_EQ(output_activation_min, 0);
177     TFLITE_DCHECK_EQ(output_activation_max, 255);
178   }
179   DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
180                 filter_offset, bias_data, bias_dims, stride_width,
181                 stride_height, pad_width, pad_height, depth_multiplier,
182                 output_offset, output_multiplier, output_shift,
183                 output_activation_min, output_activation_max, output_data,
184                 output_dims);
185 }
186 
187 // Legacy, for compatibility with old checked-in code.
188 template <FusedActivationFunctionType Ac>
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)189 void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
190                    int32 input_offset, const uint8* filter_data,
191                    const Dims<4>& filter_dims, int32 filter_offset,
192                    const int32* bias_data, const Dims<4>& bias_dims, int stride,
193                    int pad_width, int pad_height, int depth_multiplier,
194                    int32 output_offset, int32 output_multiplier,
195                    int output_shift, int32 output_activation_min,
196                    int32 output_activation_max, uint8* output_data,
197                    const Dims<4>& output_dims) {
198   DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
199                     filter_dims, filter_offset, bias_data, bias_dims, stride,
200                     stride, pad_width, pad_height, depth_multiplier,
201                     output_offset, output_multiplier, output_shift,
202                     output_activation_min, output_activation_max, output_data,
203                     output_dims);
204 }
205 
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)206 inline void Conv(const float* input_data, const Dims<4>& input_dims,
207                  const float* filter_data, const Dims<4>& filter_dims,
208                  const float* bias_data, const Dims<4>& bias_dims,
209                  int stride_width, int stride_height, int dilation_width_factor,
210                  int dilation_height_factor, int pad_width, int pad_height,
211                  float output_activation_min, float output_activation_max,
212                  float* output_data, const Dims<4>& output_dims,
213                  float* im2col_data, const Dims<4>& im2col_dims) {
214   tflite::ConvParams op_params;
215   // Padding type is ignored, but still set.
216   op_params.padding_type = PaddingType::kSame;
217   op_params.padding_values.width = pad_width;
218   op_params.padding_values.height = pad_height;
219   op_params.stride_width = stride_width;
220   op_params.stride_height = stride_height;
221   op_params.dilation_width_factor = dilation_width_factor;
222   op_params.dilation_height_factor = dilation_height_factor;
223   op_params.float_activation_min = output_activation_min;
224   op_params.float_activation_max = output_activation_max;
225 
226   Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
227        filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
228        output_data, DimsToShape(im2col_dims), im2col_data);
229 }
230 
231 template <FusedActivationFunctionType Ac>
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)232 void Conv(const float* input_data, const Dims<4>& input_dims,
233           const float* filter_data, const Dims<4>& filter_dims,
234           const float* bias_data, const Dims<4>& bias_dims, int stride_width,
235           int stride_height, int dilation_width_factor,
236           int dilation_height_factor, int pad_width, int pad_height,
237           float* output_data, const Dims<4>& output_dims, float* im2col_data,
238           const Dims<4>& im2col_dims) {
239   float output_activation_min, output_activation_max;
240   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
241   Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
242        stride_width, stride_height, dilation_width_factor,
243        dilation_height_factor, pad_width, pad_height, output_activation_min,
244        output_activation_max, output_data, output_dims, im2col_data,
245        im2col_dims);
246 }
247 
248 // legacy, for compatibility with old checked-in code
249 template <FusedActivationFunctionType Ac>
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)250 void Conv(const float* input_data, const Dims<4>& input_dims,
251           const float* filter_data, const Dims<4>& filter_dims,
252           const float* bias_data, const Dims<4>& bias_dims, int stride_width,
253           int stride_height, int pad_width, int pad_height, float* output_data,
254           const Dims<4>& output_dims, float* im2col_data,
255           const Dims<4>& im2col_dims) {
256   float output_activation_min, output_activation_max;
257   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
258   Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
259        stride_width, stride_height, 1, 1, pad_width, pad_height,
260        output_activation_min, output_activation_max, output_data, output_dims,
261        im2col_data, im2col_dims);
262 }
263 
264 // legacy, for compatibility with old checked-in code
265 template <FusedActivationFunctionType Ac>
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)266 void Conv(const float* input_data, const Dims<4>& input_dims,
267           const float* filter_data, const Dims<4>& filter_dims,
268           const float* bias_data, const Dims<4>& bias_dims, int stride,
269           int pad_width, int pad_height, float* output_data,
270           const Dims<4>& output_dims, float* im2col_data,
271           const Dims<4>& im2col_dims) {
272   Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
273            bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
274            output_dims, im2col_data, im2col_dims);
275 }
276 
Conv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,uint8 * im2col_data,const Dims<4> & im2col_dims,gemmlowp::GemmContext * gemmlowp_context)277 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
278                  int32 input_offset, const uint8* filter_data,
279                  const Dims<4>& filter_dims, int32 filter_offset,
280                  const int32* bias_data, const Dims<4>& bias_dims,
281                  int stride_width, int stride_height, int dilation_width_factor,
282                  int dilation_height_factor, int pad_width, int pad_height,
283                  int32 output_offset, int32 output_multiplier, int output_shift,
284                  int32 output_activation_min, int32 output_activation_max,
285                  uint8* output_data, const Dims<4>& output_dims,
286                  uint8* im2col_data, const Dims<4>& im2col_dims,
287                  gemmlowp::GemmContext* gemmlowp_context) {
288   tflite::ConvParams op_params;
289   // Padding type is ignored, but still set.
290   op_params.padding_type = PaddingType::kSame;
291   op_params.padding_values.width = pad_width;
292   op_params.padding_values.height = pad_height;
293   op_params.stride_width = stride_width;
294   op_params.stride_height = stride_height;
295   op_params.dilation_width_factor = dilation_width_factor;
296   op_params.dilation_height_factor = dilation_height_factor;
297   op_params.input_offset = input_offset;
298   op_params.weights_offset = filter_offset;
299   op_params.output_offset = output_offset;
300   op_params.output_multiplier = output_multiplier;
301   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
302   op_params.output_shift = kReverseShift * output_shift;
303   op_params.quantized_activation_min = output_activation_min;
304   op_params.quantized_activation_max = output_activation_max;
305 
306   Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
307        filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
308        output_data, DimsToShape(im2col_dims), im2col_data, gemmlowp_context);
309 }
310 
Conv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,uint8 * im2col_data,const Dims<4> & im2col_dims,gemmlowp::GemmContext * gemmlowp_context)311 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
312                  int32 input_offset, const uint8* filter_data,
313                  const Dims<4>& filter_dims, int32 filter_offset,
314                  const int32* bias_data, const Dims<4>& bias_dims,
315                  int stride_width, int stride_height, int pad_width,
316                  int pad_height, int32 output_offset, int32 output_multiplier,
317                  int output_shift, int32 output_activation_min,
318                  int32 output_activation_max, uint8* output_data,
319                  const Dims<4>& output_dims, uint8* im2col_data,
320                  const Dims<4>& im2col_dims,
321                  gemmlowp::GemmContext* gemmlowp_context) {
322   Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
323        filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
324        pad_width, pad_height, output_offset, output_multiplier, output_shift,
325        output_activation_min, output_activation_max, output_data, output_dims,
326        im2col_data, im2col_dims, gemmlowp_context);
327 }
328 
329 // legacy, for compatibility with old checked-in code
330 template <FusedActivationFunctionType Ac>
Conv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,uint8 * im2col_data,const Dims<4> & im2col_dims,gemmlowp::GemmContext * gemmlowp_context)331 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
332                  int32 input_offset, const uint8* filter_data,
333                  const Dims<4>& filter_dims, int32 filter_offset,
334                  const int32* bias_data, const Dims<4>& bias_dims,
335                  int stride_width, int stride_height, int pad_width,
336                  int pad_height, int32 output_offset, int32 output_multiplier,
337                  int output_shift, int32 output_activation_min,
338                  int32 output_activation_max, uint8* output_data,
339                  const Dims<4>& output_dims, uint8* im2col_data,
340                  const Dims<4>& im2col_dims,
341                  gemmlowp::GemmContext* gemmlowp_context) {
342   static_assert(Ac == FusedActivationFunctionType::kNone ||
343                     Ac == FusedActivationFunctionType::kRelu ||
344                     Ac == FusedActivationFunctionType::kRelu6 ||
345                     Ac == FusedActivationFunctionType::kRelu1,
346                 "");
347   if (Ac == FusedActivationFunctionType::kNone) {
348     TFLITE_DCHECK_EQ(output_activation_min, 0);
349     TFLITE_DCHECK_EQ(output_activation_max, 255);
350   }
351   Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
352        filter_offset, bias_data, bias_dims, stride_width, stride_height,
353        pad_width, pad_height, output_offset, output_multiplier, output_shift,
354        output_activation_min, output_activation_max, output_data, output_dims,
355        im2col_data, im2col_dims, gemmlowp_context);
356 }
357 
358 // legacy, for compatibility with old checked-in code
359 template <FusedActivationFunctionType Ac>
Conv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride,int pad_width,int pad_height,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,uint8 * im2col_data,const Dims<4> & im2col_dims,gemmlowp::GemmContext * gemmlowp_context)360 void Conv(const uint8* input_data, const Dims<4>& input_dims,
361           int32 input_offset, const uint8* filter_data,
362           const Dims<4>& filter_dims, int32 filter_offset,
363           const int32* bias_data, const Dims<4>& bias_dims, int stride,
364           int pad_width, int pad_height, int32 output_offset,
365           int32 output_multiplier, int output_shift,
366           int32 output_activation_min, int32 output_activation_max,
367           uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
368           const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) {
369   Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims,
370            filter_offset, bias_data, bias_dims, stride, stride, pad_width,
371            pad_height, output_offset, output_multiplier, output_shift,
372            output_activation_min, output_activation_max, output_data,
373            output_dims, im2col_data, im2col_dims, gemmlowp_context);
374 }
375 
TransposeConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,int stride_width,int stride_height,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)376 inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
377                           const float* filter_data, const Dims<4>& filter_dims,
378                           int stride_width, int stride_height, int pad_width,
379                           int pad_height, float* output_data,
380                           const Dims<4>& output_dims, float* im2col_data,
381                           const Dims<4>& im2col_dims) {
382   tflite::ConvParams op_params;
383   // Padding type is ignored, but still set.
384   op_params.padding_type = PaddingType::kSame;
385   op_params.padding_values.width = pad_width;
386   op_params.padding_values.height = pad_height;
387   op_params.stride_width = stride_width;
388   op_params.stride_height = stride_height;
389 
390   TransposeConv(op_params, DimsToShape(input_dims), input_data,
391                 DimsToShape(filter_dims), filter_data,
392                 /*bias_shape*/ RuntimeShape(), /*bias*/ nullptr,
393                 DimsToShape(output_dims), output_data, DimsToShape(im2col_dims),
394                 im2col_data);
395 }
396 
TransposeConv(const ConvParams & params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & filter_shape,const float * filter_data,const RuntimeShape & output_shape,float * output_data,const RuntimeShape & im2col_shape,float * im2col_data)397 inline void TransposeConv(
398     const ConvParams& params, const RuntimeShape& input_shape,
399     const float* input_data, const RuntimeShape& filter_shape,
400     const float* filter_data, const RuntimeShape& output_shape,
401     float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
402   TransposeConv(params, input_shape, input_data, filter_shape, filter_data,
403                 /*bias_shape*/ RuntimeShape(), /*bias*/ nullptr, output_shape,
404                 output_data, im2col_shape, im2col_data);
405 }
406 
FullyConnected(const float * input_data,const Dims<4> & input_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)407 inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
408                            const float* weights_data,
409                            const Dims<4>& weights_dims, const float* bias_data,
410                            const Dims<4>& bias_dims,
411                            float output_activation_min,
412                            float output_activation_max, float* output_data,
413                            const Dims<4>& output_dims) {
414   tflite::FullyConnectedParams op_params;
415   op_params.float_activation_min = output_activation_min;
416   op_params.float_activation_max = output_activation_max;
417 
418   FullyConnected(op_params, DimsToShape(input_dims), input_data,
419                  DimsToShape(weights_dims), weights_data,
420                  DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
421                  output_data);
422 }
423 
424 // legacy, for compatibility with old checked-in code
425 template <FusedActivationFunctionType Ac>
FullyConnected(const float * input_data,const Dims<4> & input_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,float * output_data,const Dims<4> & output_dims)426 void FullyConnected(const float* input_data, const Dims<4>& input_dims,
427                     const float* weights_data, const Dims<4>& weights_dims,
428                     const float* bias_data, const Dims<4>& bias_dims,
429                     float* output_data, const Dims<4>& output_dims) {
430   float output_activation_min, output_activation_max;
431   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
432   FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
433                  bias_dims, output_activation_min, output_activation_max,
434                  output_data, output_dims);
435 }
436 
// Legacy quantized (uint8 output) FullyConnected that still accepts a
// gemmlowp::GemmContext. The reference implementation does not use the
// context, so it is dropped and the call is forwarded unchanged.
inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, gemmlowp::GemmContext*) {
  FullyConnected(params, input_shape, input_data, filter_shape, filter_data,
                 bias_shape, bias_data, output_shape, output_data);
}
446 
// Legacy quantized (int16 output) FullyConnected that still accepts a
// gemmlowp::GemmContext. The reference implementation does not use the
// context, so it is dropped and the call is forwarded unchanged.
inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    int16* output_data, gemmlowp::GemmContext*) {
  FullyConnected(params, input_shape, input_data, filter_shape, filter_data,
                 bias_shape, bias_data, output_shape, output_data);
}
456 
FullyConnected(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,gemmlowp::GemmContext * gemmlowp_context)457 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
458                            int32 input_offset, const uint8* filter_data,
459                            const Dims<4>& filter_dims, int32 filter_offset,
460                            const int32* bias_data, const Dims<4>& bias_dims,
461                            int32 output_offset, int32 output_multiplier,
462                            int output_shift, int32 output_activation_min,
463                            int32 output_activation_max, uint8* output_data,
464                            const Dims<4>& output_dims,
465                            gemmlowp::GemmContext* gemmlowp_context) {
466   tflite::FullyConnectedParams op_params;
467   op_params.input_offset = input_offset;
468   op_params.weights_offset = filter_offset;
469   op_params.output_offset = output_offset;
470   op_params.output_multiplier = output_multiplier;
471   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
472   op_params.output_shift = kReverseShift * output_shift;
473   op_params.quantized_activation_min = output_activation_min;
474   op_params.quantized_activation_max = output_activation_max;
475 
476   FullyConnected(op_params, DimsToShape(input_dims), input_data,
477                  DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
478                  bias_data, DimsToShape(output_dims), output_data,
479                  gemmlowp_context);
480 }
481 
FullyConnected(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,int16 * output_data,const Dims<4> & output_dims,gemmlowp::GemmContext * gemmlowp_context)482 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
483                            int32 input_offset, const uint8* filter_data,
484                            const Dims<4>& filter_dims, int32 filter_offset,
485                            const int32* bias_data, const Dims<4>& bias_dims,
486                            int32 output_offset, int32 output_multiplier,
487                            int output_shift, int32 output_activation_min,
488                            int32 output_activation_max, int16* output_data,
489                            const Dims<4>& output_dims,
490                            gemmlowp::GemmContext* gemmlowp_context) {
491   tflite::FullyConnectedParams op_params;
492   op_params.input_offset = input_offset;
493   op_params.weights_offset = filter_offset;
494   op_params.output_offset = output_offset;
495   op_params.output_multiplier = output_multiplier;
496   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
497   op_params.output_shift = kReverseShift * output_shift;
498   op_params.quantized_activation_min = output_activation_min;
499   op_params.quantized_activation_max = output_activation_max;
500 
501   FullyConnected(op_params, DimsToShape(input_dims), input_data,
502                  DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
503                  bias_data, DimsToShape(output_dims), output_data,
504                  gemmlowp_context);
505 }
506 
// Legacy ShuffledFullyConnected that still accepts a gemmlowp::GemmContext.
// The reference implementation does not use the context, so it is dropped
// and the call is forwarded unchanged.
inline void ShuffledFullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& weights_shape,
    const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    int16* output_data, uint8* shuffled_input_workspace_data,
    gemmlowp::GemmContext*) {
  ShuffledFullyConnected(params, input_shape, input_data, weights_shape,
                         shuffled_weights_data, bias_shape, bias_data,
                         output_shape, output_data,
                         shuffled_input_workspace_data);
}
519 
ShuffledFullyConnected(const uint8 * input_data,const Dims<4> & input_dims,const uint8 * shuffled_weights_data,const Dims<4> & weights_dims,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,int16 * output_data,const Dims<4> & output_dims,uint8 * shuffled_input_workspace_data,gemmlowp::GemmContext * gemmlowp_context)520 inline void ShuffledFullyConnected(
521     const uint8* input_data, const Dims<4>& input_dims,
522     const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
523     const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
524     int output_shift, int32 output_activation_min, int32 output_activation_max,
525     int16* output_data, const Dims<4>& output_dims,
526     uint8* shuffled_input_workspace_data,
527     gemmlowp::GemmContext* gemmlowp_context) {
528   tflite::FullyConnectedParams op_params;
529   op_params.output_multiplier = output_multiplier;
530   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
531   op_params.output_shift = kReverseShift * output_shift;
532   op_params.quantized_activation_min = output_activation_min;
533   op_params.quantized_activation_max = output_activation_max;
534 
535   ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
536                          DimsToShape(weights_dims), shuffled_weights_data,
537                          DimsToShape(bias_dims), bias_data,
538                          DimsToShape(output_dims), output_data,
539                          shuffled_input_workspace_data, gemmlowp_context);
540 }
541 
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
                    int32 input_offset, const uint8* filter_data,
                    const Dims<4>& filter_dims, int32 filter_offset,
                    const int32* bias_data, const Dims<4>& bias_dims,
                    int32 output_offset, int32 output_multiplier,
                    int output_shift, int32 output_activation_min,
                    int32 output_activation_max, uint8* output_data,
                    const Dims<4>& output_dims,
                    gemmlowp::GemmContext* gemmlowp_context) {
  // Only the known fused-activation kinds are legal template arguments.
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  if (Ac == FusedActivationFunctionType::kNone) {
    // With no fused activation the clamp range must span the full uint8
    // range; anything narrower would silently change results.
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  // Forward to the non-templated legacy overload, which builds the params
  // struct and calls the shape-based implementation.
  FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
                 filter_offset, bias_data, bias_dims, output_offset,
                 output_multiplier, output_shift, output_activation_min,
                 output_activation_max, output_data, output_dims,
                 gemmlowp_context);
}
568 
LstmCell(const float * input_data,const Dims<4> & input_dims,const float * prev_activ_data,const Dims<4> & prev_activ_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,const float * prev_state_data,const Dims<4> & prev_state_dims,float * output_state_data,const Dims<4> & output_state_dims,float * output_activ_data,const Dims<4> & output_activ_dims,float * concat_temp_data,const Dims<4> & concat_temp_dims,float * activ_temp_data,const Dims<4> & activ_temp_dims)569 inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
570                      const float* prev_activ_data,
571                      const Dims<4>& prev_activ_dims, const float* weights_data,
572                      const Dims<4>& weights_dims, const float* bias_data,
573                      const Dims<4>& bias_dims, const float* prev_state_data,
574                      const Dims<4>& prev_state_dims, float* output_state_data,
575                      const Dims<4>& output_state_dims, float* output_activ_data,
576                      const Dims<4>& output_activ_dims, float* concat_temp_data,
577                      const Dims<4>& concat_temp_dims, float* activ_temp_data,
578                      const Dims<4>& activ_temp_dims) {
579   tflite::LstmCellParams op_params;
580   // Float LSTM cell does not need parameters to be set: leave untouched.
581 
582   LstmCell(op_params, DimsToShape(input_dims), input_data,
583            DimsToShape(prev_activ_dims), prev_activ_data,
584            DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
585            bias_data, DimsToShape(prev_state_dims), prev_state_data,
586            DimsToShape(output_state_dims), output_state_data,
587            DimsToShape(output_activ_dims), output_activ_data,
588            DimsToShape(concat_temp_dims), concat_temp_data,
589            DimsToShape(activ_temp_dims), activ_temp_data);
590 }
591 
592 template <int StateIntegerBits>
LstmCell(const uint8 * input_data_uint8,const Dims<4> & input_dims,const uint8 * prev_activ_data_uint8,const Dims<4> & prev_activ_dims,const uint8 * weights_data_uint8,const Dims<4> & weights_dims,const int32 * bias_data_int32,const Dims<4> & bias_dims,const int16 * prev_state_data_int16,const Dims<4> & prev_state_dims,int16 * output_state_data_int16,const Dims<4> & output_state_dims,uint8 * output_activ_data_uint8,const Dims<4> & output_activ_dims,uint8 * concat_temp_data_uint8,const Dims<4> & concat_temp_dims,int16 * activ_temp_data_int16,const Dims<4> & activ_temp_dims,int32 weights_zero_point,int32 accum_multiplier,int accum_shift,gemmlowp::GemmContext * gemmlowp_context)593 void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
594               const uint8* prev_activ_data_uint8,
595               const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
596               const Dims<4>& weights_dims, const int32* bias_data_int32,
597               const Dims<4>& bias_dims, const int16* prev_state_data_int16,
598               const Dims<4>& prev_state_dims, int16* output_state_data_int16,
599               const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
600               const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
601               const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
602               const Dims<4>& activ_temp_dims, int32 weights_zero_point,
603               int32 accum_multiplier, int accum_shift,
604               gemmlowp::GemmContext* gemmlowp_context) {
605   tflite::LstmCellParams op_params;
606   op_params.weights_zero_point = weights_zero_point;
607   op_params.accum_multiplier = accum_multiplier;
608   op_params.accum_shift = accum_shift;
609 
610   LstmCell<StateIntegerBits>(
611       op_params, DimsToShape(input_dims), input_data_uint8,
612       DimsToShape(prev_activ_dims), prev_activ_data_uint8,
613       DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
614       bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
615       DimsToShape(output_state_dims), output_state_data_int16,
616       DimsToShape(output_activ_dims), output_activ_data_uint8,
617       DimsToShape(concat_temp_dims), concat_temp_data_uint8,
618       DimsToShape(activ_temp_dims), activ_temp_data_int16, gemmlowp_context);
619 }
620 
621 template <typename T>
BroadcastDiv(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)622 void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
623                   const T* input2_data, const Dims<4>& input2_dims,
624                   T output_activation_min, T output_activation_max,
625                   T* output_data, const Dims<4>& output_dims) {
626   tflite::ArithmeticParams op_params;
627   SetActivationParams(output_activation_min, output_activation_max, &op_params);
628 
629   BroadcastDivSlow(op_params, DimsToShape(input1_dims), input1_data,
630                    DimsToShape(input2_dims), input2_data,
631                    DimsToShape(output_dims), output_data);
632 }
633 
634 template <typename T>
Div(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)635 inline void Div(const T* input1_data, const Dims<4>& input1_dims,
636                 const T* input2_data, const Dims<4>& input2_dims,
637                 T output_activation_min, T output_activation_max,
638                 T* output_data, const Dims<4>& output_dims) {
639   tflite::ArithmeticParams op_params;
640   SetActivationParams(output_activation_min, output_activation_max, &op_params);
641 
642   Div(op_params, DimsToShape(input1_dims), input1_data,
643       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
644       output_data);
645 }
646 
647 template <FusedActivationFunctionType Ac, typename Scalar>
Concatenation(int concat_dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)648 inline void Concatenation(int concat_dim, const Scalar* const* input_data,
649                           const Dims<4>* const* input_dims, int inputs_count,
650                           Scalar* output_data, const Dims<4>& output_dims) {
651   // For now we don't have a model with a Concatenation with fused activation.
652   TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
653 
654   std::vector<RuntimeShape> input_shapes(inputs_count);
655   std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
656   for (int i = 0; i < inputs_count; ++i) {
657     ShapeFromDims(*input_dims[i], &input_shapes[i]);
658     input_shapes_indirect[i] = &input_shapes[i];
659   }
660   tflite::ConcatenationParams op_params;
661   op_params.axis = 3 - concat_dim;
662   op_params.inputs_count = inputs_count;
663 
664   Concatenation(op_params, input_shapes_indirect.data(), input_data,
665                 DimsToShape(output_dims), output_data);
666 }
667 
Concatenation(int concat_dim,const uint8 * const * input_data,const Dims<4> * const * input_dims,const int32 * input_zeropoint,const float * input_scale,int inputs_count,uint8 * output_data,const Dims<4> & output_dims,const int32 output_zeropoint,const float output_scale)668 inline void Concatenation(int concat_dim, const uint8* const* input_data,
669                           const Dims<4>* const* input_dims,
670                           const int32* input_zeropoint,
671                           const float* input_scale, int inputs_count,
672                           uint8* output_data, const Dims<4>& output_dims,
673                           const int32 output_zeropoint,
674                           const float output_scale) {
675   std::vector<RuntimeShape> input_shapes(inputs_count);
676   std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
677   for (int i = 0; i < inputs_count; ++i) {
678     ShapeFromDims(*input_dims[i], &input_shapes[i]);
679     input_shapes_indirect[i] = &input_shapes[i];
680   }
681   tflite::ConcatenationParams op_params;
682   op_params.axis = 3 - concat_dim;
683   op_params.input_zeropoint = input_zeropoint;
684   op_params.input_scale = input_scale;
685   op_params.inputs_count = inputs_count;
686   op_params.output_zeropoint = output_zeropoint;
687   op_params.output_scale = output_scale;
688 
689   ConcatenationWithScaling(op_params, input_shapes_indirect.data(), input_data,
690                            DimsToShape(output_dims), output_data);
691 }
692 
693 template <FusedActivationFunctionType Ac, typename Scalar>
DepthConcatenation(const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)694 void DepthConcatenation(const Scalar* const* input_data,
695                         const Dims<4>* const* input_dims, int inputs_count,
696                         Scalar* output_data, const Dims<4>& output_dims) {
697   // For now we don't have a model with a Concatenation with fused activation.
698   TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
699 
700   std::vector<RuntimeShape> input_shapes(inputs_count);
701   std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
702   for (int i = 0; i < inputs_count; ++i) {
703     ShapeFromDims(*input_dims[i], &input_shapes[i]);
704     input_shapes_indirect[i] = &input_shapes[i];
705   }
706   tflite::ConcatenationParams op_params;
707   op_params.inputs_count = inputs_count;
708 
709   DepthConcatenation(op_params, input_shapes_indirect.data(), input_data,
710                      DimsToShape(output_dims), output_data);
711 }
712 
713 template <typename Scalar>
TensorFlowSplit(const Scalar * input_data,const Dims<4> & input_dims,int axis,int outputs_count,Scalar * const * output_data,const Dims<4> * const * output_dims)714 void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
715                      int axis, int outputs_count, Scalar* const* output_data,
716                      const Dims<4>* const* output_dims) {
717   std::vector<RuntimeShape> output_shapes(outputs_count);
718   std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count);
719   for (int i = 0; i < outputs_count; ++i) {
720     ShapeFromDims(*output_dims[i], &output_shapes[i]);
721     output_shapes_indirect[i] = &output_shapes[i];
722   }
723   tflite::SplitParams op_params;
724   op_params.axis = 3 - axis;
725   op_params.num_split = outputs_count;
726 
727   Split(op_params, DimsToShape(input_dims), input_data,
728         output_shapes_indirect.data(), output_data);
729 }
730 
731 template <FusedActivationFunctionType Ac, typename Scalar>
TensorFlowSplit(const Scalar * input_data,const Dims<4> & input_dims,int outputs_count,Scalar * const * output_data,const Dims<4> * const * output_dims)732 void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
733                      int outputs_count, Scalar* const* output_data,
734                      const Dims<4>* const* output_dims) {
735   TFLITE_DCHECK_GE(outputs_count, 1);
736   for (int i = 0; i < outputs_count; i++) {
737     /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
738     /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
739     /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
740   }
741   // For now we don't have a model with a Split with fused activation.
742   TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
743 
744   TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count,
745                   output_data, output_dims);
746 }
747 
Softmax(const float * input_data,const RuntimeShape & input_shape,float beta,float * output_data,const RuntimeShape & output_shape)748 inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
749                     float beta, float* output_data,
750                     const RuntimeShape& output_shape) {
751   SoftmaxParams params;
752   params.beta = beta;
753   Softmax(params, input_shape, input_data, output_shape, output_data);
754 }
755 
Softmax(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_beta_multiplier,int32 input_beta_left_shift,int diff_min,uint8 * output_data,const RuntimeShape & output_shape)756 inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
757                     int32 input_beta_multiplier, int32 input_beta_left_shift,
758                     int diff_min, uint8* output_data,
759                     const RuntimeShape& output_shape) {
760   SoftmaxParams params;
761   params.input_multiplier = input_beta_multiplier;
762   params.input_left_shift = input_beta_left_shift;
763   params.diff_min = diff_min;
764   Softmax(params, input_shape, input_data, output_shape, output_data);
765 }
766 
LogSoftmax(const float * input_data,const RuntimeShape & input_shape,float * output_data,const RuntimeShape & output_shape)767 inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
768                        float* output_data, const RuntimeShape& output_shape) {
769   SoftmaxParams params;
770   // No params currently used for float LogSoftmax.
771   LogSoftmax(params, input_shape, input_data, output_shape, output_data);
772 }
773 
LogSoftmax(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_multiplier,int32 input_left_shift,int32 reverse_scaling_divisor,int32 reverse_scaling_right_shift,int diff_min,uint8 * output_data,const RuntimeShape & output_shape)774 inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
775                        int32 input_multiplier, int32 input_left_shift,
776                        int32 reverse_scaling_divisor,
777                        int32 reverse_scaling_right_shift, int diff_min,
778                        uint8* output_data, const RuntimeShape& output_shape) {
779   SoftmaxParams params;
780   params.input_multiplier = input_multiplier;
781   params.input_left_shift = input_left_shift;
782   params.reverse_scaling_divisor = reverse_scaling_divisor;
783   params.reverse_scaling_right_shift = reverse_scaling_right_shift;
784   params.diff_min = diff_min;
785   LogSoftmax(params, input_shape, input_data, output_shape, output_data);
786 }
787 
// Legacy quantized uint8 logistic (sigmoid) implemented with gemmlowp
// fixed-point arithmetic. Elements whose centered value falls outside
// +/- input_range_radius saturate to 0 / 255; the rest are rescaled,
// evaluated via gemmlowp::logistic, and requantized to uint8.
inline void Logistic(const LogisticParams& params,
                     const RuntimeShape& input_shape, const uint8* input_data,
                     const RuntimeShape& output_shape, uint8* output_data) {
  const int32 input_zero_point = params.input_zero_point;
  const int32 input_range_radius = params.input_range_radius;
  const int32 input_multiplier = params.input_multiplier;
  const int input_left_shift = params.input_left_shift;
  const int flat_size = MatchingFlatSize(input_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    const uint8 input_val_u8 = input_data[i];
    // Center the input around its zero point before range checks.
    const int32 input_val_centered =
        static_cast<int32>(input_val_u8) - input_zero_point;
    uint8 output_val;
    if (input_val_centered <= -input_range_radius) {
      // Far negative input: sigmoid saturates to 0.
      output_val = 0;
    } else if (input_val_centered >= input_range_radius) {
      // Far positive input: sigmoid saturates to 255 (i.e. ~1.0).
      output_val = 255;
    } else {
      // Rescale into the fixed-point domain expected by gemmlowp::logistic.
      const int32 input_val_rescaled =
          MultiplyByQuantizedMultiplierGreaterThanOne(
              input_val_centered, input_multiplier, input_left_shift);
      using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
      using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
      const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
      const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
      // Convert from Q0.31 to Q23.8.
      using gemmlowp::RoundingDivideByPOT;
      int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23);
      // Rounding can push the result one past the top of the uint8 range;
      // clamp that single overflow case back to 255.
      if (output_val_s32 == 256) {
        output_val_s32 = 255;
      }
      // Reinterpret as U0.8.
      TFLITE_DCHECK_GE(output_val_s32, 0);
      TFLITE_DCHECK_LE(output_val_s32, 255);
      output_val = static_cast<uint8>(output_val_s32);
    }
    output_data[i] = output_val;
  }
}
828 
Logistic(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const RuntimeShape & output_shape)829 inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
830                      int32 input_zero_point, int32 input_range_radius,
831                      int32 input_multiplier, int input_left_shift,
832                      uint8* output_data, const RuntimeShape& output_shape) {
833   LogisticParams params;
834   params.input_zero_point = input_zero_point;
835   params.input_range_radius = input_range_radius;
836   params.input_multiplier = input_multiplier;
837   params.input_left_shift = input_left_shift;
838   Logistic(params, input_shape, input_data, output_shape, output_data);
839 }
840 
Logistic(const RuntimeShape & input_shape,const int16 * input_data,const RuntimeShape & output_shape,int16 * output_data)841 inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
842                      const RuntimeShape& output_shape, int16* output_data) {
843   LogisticParams params;
844   // No params currently needed by int16 Logistic.
845   Logistic(params, input_shape, input_data, output_shape, output_data);
846 }
847 
Tanh(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const RuntimeShape & output_shape)848 inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
849                  int32 input_zero_point, int32 input_range_radius,
850                  int32 input_multiplier, int input_left_shift,
851                  uint8* output_data, const RuntimeShape& output_shape) {
852   TanhParams params;
853   params.input_zero_point = input_zero_point;
854   params.input_range_radius = input_range_radius;
855   params.input_multiplier = input_multiplier;
856   params.input_left_shift = input_left_shift;
857   Tanh(params, input_shape, input_data, output_shape, output_data);
858 }
859 
Tanh(const int16 * input_data,const RuntimeShape & input_shape,int input_left_shift,int16 * output_data,const RuntimeShape & output_shape)860 inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
861                  int input_left_shift, int16* output_data,
862                  const RuntimeShape& output_shape) {
863   TanhParams params;
864   params.input_left_shift = input_left_shift;
865   Tanh(params, input_shape, input_data, output_shape, output_data);
866 }
867 
Dequantize(const uint8 * input_data,const Dims<4> & input_dims,int32 zero_point,double scale,float * output_data,const Dims<4> & output_dims)868 inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
869                        int32 zero_point, double scale, float* output_data,
870                        const Dims<4>& output_dims) {
871   tflite::DequantizationParams op_params;
872   op_params.zero_point = zero_point;
873   op_params.scale = scale;
874 
875   Dequantize(op_params, DimsToShape(input_dims), input_data,
876              DimsToShape(output_dims), output_data);
877 }
878 
FakeQuant(const float * input_data,const Dims<4> & input_dims,float rmin,float rmax,int num_bits,float * output_data,const Dims<4> & output_dims)879 inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
880                       float rmin, float rmax, int num_bits, float* output_data,
881                       const Dims<4>& output_dims) {
882   tflite::FakeQuantParams op_params;
883   op_params.num_bits = num_bits;
884   op_params.minmax.min = rmin;
885   op_params.minmax.max = rmax;
886 
887   FakeQuant(op_params, DimsToShape(input_dims), input_data,
888             DimsToShape(output_dims), output_data);
889 }
890 
891 template <typename T>
Gather(const T * input_data,const Dims<4> & input_dims,int input_rank,const int32 * coords_data,const Dims<4> & coords_dims,T * output_data,const Dims<4> & output_dims)892 inline void Gather(const T* input_data, const Dims<4>& input_dims,
893                    int input_rank, const int32* coords_data,
894                    const Dims<4>& coords_dims, T* output_data,
895                    const Dims<4>& output_dims) {
896   tflite::GatherParams op_params;
897   op_params.axis = 4 - input_rank;
898 
899   Gather(op_params, DimsToShape(input_dims), input_data,
900          DimsToShape(coords_dims), coords_data, DimsToShape(output_dims),
901          output_data);
902 }
903 
LegacyReverseBits32(uint32 n)904 inline uint32 LegacyReverseBits32(uint32 n) {
905   n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
906   n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
907   n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
908   return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
909           ((n & 0xFF000000) >> 24));
910 }
911 
// Rewrites legacy strided-slice parameters (stored in reversed Dims<4>
// order) into the current dimension order, in place: the start/stop/stride
// arrays are reversed, and each bit mask is mirrored so that bit i still
// refers to the same (now reversed) dimension.
inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
  // All three per-dimension arrays must describe the same number of dims.
  TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
  TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);

  std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
  std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
  std::reverse(p->strides, p->strides + p->strides_count);

  // Reverse all 32 bits of each mask, then shift right so that only the
  // low start_indices_count bits (one per actual dimension) are kept.
  p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
                  (32 - p->start_indices_count);
  p->ellipsis_mask =
      LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
      (32 - p->start_indices_count);
  p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
                (32 - p->start_indices_count);
  p->new_axis_mask =
      LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
      (32 - p->start_indices_count);
  p->shrink_axis_mask =
      LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
      (32 - p->start_indices_count);
}
934 
935 template <typename T>
StridedSlice(const T * input_data,const Dims<4> & input_dims,int begin_mask,int end_mask,int shrink_axis_mask,const std::vector<int> & start_indices,const std::vector<int> & stop_indices,const std::vector<int> & strides,T * output_data,const Dims<4> & output_dims)936 inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
937                          int begin_mask, int end_mask, int shrink_axis_mask,
938                          const std::vector<int>& start_indices,
939                          const std::vector<int>& stop_indices,
940                          const std::vector<int>& strides, T* output_data,
941                          const Dims<4>& output_dims) {
942   TFLITE_DCHECK_EQ(start_indices.size(), 4);
943   auto op_params = strided_slice::BuildStridedSliceParams(
944       begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
945       strides);
946   StridedSliceReverseIndices(&op_params);
947 
948   StridedSlice(op_params, DimsToShape(input_dims), input_data,
949                DimsToShape(output_dims), output_data);
950 }
951 
952 template <typename T>
Mean(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & reduction_indices,T * output_data,const Dims<4> & output_dims)953 inline void Mean(const T* input_data, const Dims<4>& input_dims,
954                  const std::vector<int>& reduction_indices, T* output_data,
955                  const Dims<4>& output_dims) {
956   tflite::MeanParams op_params;
957   op_params.axis_count = reduction_indices.size();
958   for (int i = 0; i < op_params.axis_count; ++i) {
959     op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i];
960   }
961 
962   Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
963        output_data);
964 }
965 
966 template <typename T>
Transpose(const T * input,const Dims<4> & input_dims,T * output,const Dims<4> & output_dims,const int * permuted_axes)967 void Transpose(const T* input, const Dims<4>& input_dims, T* output,
968                const Dims<4>& output_dims, const int* permuted_axes) {
969   TransposeParams params;
970   params.perm_count = 4;
971   for (int i = 0; i < 4; ++i) {
972     params.perm[i] = 3 - permuted_axes[3 - i];
973   }
974   Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
975             output);
976 }
977 
978 template <typename T, ComparisonFn<T> F>
Comparison(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,bool * output_data,const Dims<4> & output_dims)979 inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
980                        const T* input2_data, const Dims<4>& input2_dims,
981                        bool* output_data, const Dims<4>& output_dims) {
982   ComparisonParams op_params;
983   // No parameters needed.
984   ComparisonImpl<T, F>(op_params, DimsToShape(input1_dims), input1_data,
985                        DimsToShape(input2_dims), input2_data,
986                        DimsToShape(output_dims), output_data);
987 }
988 
989 template <typename T, ComparisonFn<int32> F>
Comparison(int left_shift,const T * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const T * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,bool * output_data,const Dims<4> & output_dims)990 inline void Comparison(int left_shift, const T* input1_data,
991                        const Dims<4>& input1_dims, int32 input1_offset,
992                        int32 input1_multiplier, int input1_shift,
993                        const T* input2_data, const Dims<4>& input2_dims,
994                        int32 input2_offset, int32 input2_multiplier,
995                        int input2_shift, bool* output_data,
996                        const Dims<4>& output_dims) {
997   tflite::ComparisonParams op_params;
998   op_params.left_shift = left_shift;
999   op_params.input1_offset = input1_offset;
1000   op_params.input1_multiplier = input1_multiplier;
1001   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1002   op_params.input1_shift = kReverseShift * input1_shift;
1003   op_params.input2_offset = input2_offset;
1004   op_params.input2_multiplier = input2_multiplier;
1005   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1006   op_params.input2_shift = kReverseShift * input2_shift;
1007 
1008   ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
1009                               DimsToShape(input2_dims), input2_data,
1010                               DimsToShape(output_dims), output_data);
1011 }
1012 
1013 template <typename T, ComparisonFn<T> F>
BroadcastComparison(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,bool * output_data,const Dims<4> & output_dims)1014 inline void BroadcastComparison(const T* input1_data,
1015                                 const Dims<4>& input1_dims,
1016                                 const T* input2_data,
1017                                 const Dims<4>& input2_dims, bool* output_data,
1018                                 const Dims<4>& output_dims) {
1019   ComparisonParams op_params;
1020   // No parameters needed.
1021   BroadcastComparison4DSlowImpl<T, F>(op_params, DimsToShape(input1_dims),
1022                                       input1_data, DimsToShape(input2_dims),
1023                                       input2_data, DimsToShape(output_dims),
1024                                       output_data);
1025 }
1026 
1027 template <typename T, ComparisonFn<int32> F>
BroadcastComparison(int left_shift,const T * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const T * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,bool * output_data,const Dims<4> & output_dims)1028 inline void BroadcastComparison(int left_shift, const T* input1_data,
1029                                 const Dims<4>& input1_dims, int32 input1_offset,
1030                                 int32 input1_multiplier, int input1_shift,
1031                                 const T* input2_data,
1032                                 const Dims<4>& input2_dims, int32 input2_offset,
1033                                 int32 input2_multiplier, int input2_shift,
1034                                 bool* output_data, const Dims<4>& output_dims) {
1035   ComparisonParams op_params;
1036 
1037   op_params.left_shift = left_shift;
1038   op_params.input1_offset = input1_offset;
1039   op_params.input1_multiplier = input1_multiplier;
1040   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1041   op_params.input1_shift = kReverseShift * input1_shift;
1042   op_params.input2_offset = input2_offset;
1043   op_params.input2_multiplier = input2_multiplier;
1044   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1045   op_params.input2_shift = kReverseShift * input2_shift;
1046 
1047   BroadcastComparison4DSlowWithScaling<T, F>(
1048       op_params, DimsToShape(input1_dims), input1_data,
1049       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1050       output_data);
1051 }
1052 
// Expands to the four legacy entry points for a comparison op `name`:
//   name(...)              - plain per-element comparison
//   name(left_shift, ...)  - quantized path with per-input rescaling
//   Broadcast##name(...)   - broadcasting variants of both of the above
// Each wrapper adds a profiler scope label and forwards to the shared
// Comparison/BroadcastComparison helpers with the `name##Fn` functor.
#define TFLITE_LEGACY_COMPARISON_OP(name)                                     \
  template <typename T>                                                       \
  inline void name(const T* input1_data, const Dims<4>& input1_dims,          \
                   const T* input2_data, const Dims<4>& input2_dims,          \
                   bool* output_data, const Dims<4>& output_dims) {           \
    ruy::profiler::ScopeLabel label(#name);                                   \
    Comparison<T, name##Fn>(input1_data, input1_dims, input2_data,            \
                            input2_dims, output_data, output_dims);           \
  }                                                                           \
  template <typename T>                                                       \
  inline void name(                                                           \
      int left_shift, const T* input1_data, const Dims<4>& input1_dims,       \
      int32 input1_offset, int32 input1_multiplier, int input1_shift,         \
      const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset,  \
      int32 input2_multiplier, int input2_shift, bool* output_data,           \
      const Dims<4>& output_dims) {                                           \
    ruy::profiler::ScopeLabel label(#name "/8bit");                           \
    Comparison<T, name##Fn>(left_shift, input1_data, input1_dims,             \
                            input1_offset, input1_multiplier, input1_shift,   \
                            input2_data, input2_dims, input2_offset,          \
                            input2_multiplier, input2_shift, output_data,     \
                            output_dims);                                     \
  }                                                                           \
  template <typename T>                                                       \
  inline void Broadcast##name(                                                \
      const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
      const Dims<4>& input2_dims, bool* output_data,                          \
      const Dims<4>& output_dims) {                                           \
    ruy::profiler::ScopeLabel label("Broadcast" #name);                       \
    BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data,   \
                                     input2_dims, output_data, output_dims);  \
  }                                                                           \
  template <typename T>                                                       \
  inline void Broadcast##name(                                                \
      int left_shift, const T* input1_data, const Dims<4>& input1_dims,       \
      int32 input1_offset, int32 input1_multiplier, int input1_shift,         \
      const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset,  \
      int32 input2_multiplier, int input2_shift, bool* output_data,           \
      const Dims<4>& output_dims) {                                           \
    ruy::profiler::ScopeLabel label("Broadcast" #name "/8bit");               \
    BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims,    \
                                     input1_offset, input1_multiplier,        \
                                     input1_shift, input2_data, input2_dims,  \
                                     input2_offset, input2_multiplier,        \
                                     input2_shift, output_data, output_dims); \
  }
// Instantiate the legacy wrappers for every supported comparison.
TFLITE_LEGACY_COMPARISON_OP(Equal);
TFLITE_LEGACY_COMPARISON_OP(NotEqual);
TFLITE_LEGACY_COMPARISON_OP(Greater);
TFLITE_LEGACY_COMPARISON_OP(GreaterEqual);
TFLITE_LEGACY_COMPARISON_OP(Less);
TFLITE_LEGACY_COMPARISON_OP(LessEqual);
#undef TFLITE_LEGACY_COMPARISON_OP
1106 
1107 template <typename D, typename T>
Select(const D * input_condition_data,const Dims<4> & input_condition_dims,const T * input_x_data,const Dims<4> & input_x_dims,const T * input_y_data,const Dims<4> & input_y_dims,T * output_data,const Dims<4> & output_dims)1108 inline void Select(const D* input_condition_data,
1109                    const Dims<4>& input_condition_dims, const T* input_x_data,
1110                    const Dims<4>& input_x_dims, const T* input_y_data,
1111                    const Dims<4>& input_y_dims, T* output_data,
1112                    const Dims<4>& output_dims) {
1113   Select(DimsToShape(input_condition_dims), input_condition_data,
1114          DimsToShape(input_x_dims), input_x_data, DimsToShape(input_y_dims),
1115          input_y_data, DimsToShape(output_dims), output_data);
1116 }
1117 
1118 template <typename D, typename T>
RankOneSelect(const D * input_condition_data,const Dims<4> & input_condition_dims,const T * input_x_data,const Dims<4> & input_x_dims,const T * input_y_data,const Dims<4> & input_y_dims,T * output_data,const Dims<4> & output_dims)1119 inline void RankOneSelect(const D* input_condition_data,
1120                           const Dims<4>& input_condition_dims,
1121                           const T* input_x_data, const Dims<4>& input_x_dims,
1122                           const T* input_y_data, const Dims<4>& input_y_dims,
1123                           T* output_data, const Dims<4>& output_dims) {
1124   RankOneSelect(DimsToShape(input_condition_dims), input_condition_data,
1125                 DimsToShape(input_x_dims), input_x_data,
1126                 DimsToShape(input_y_dims), input_y_data,
1127                 DimsToShape(output_dims), output_data);
1128 }
1129 
1130 template <typename T, typename TI>
SparseToDense(const std::vector<std::vector<TI>> & indices,const T * values,T default_value,T * output_data,const Dims<4> & output_dims,bool value_is_scalar)1131 inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
1132                           const T* values, T default_value, T* output_data,
1133                           const Dims<4>& output_dims, bool value_is_scalar) {
1134   SparseToDense(indices, values, default_value, value_is_scalar,
1135                 DimsToShape(output_dims), output_data);
1136 }
1137 
1138 template <typename Scalar>
Pack(int dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)1139 void Pack(int dim, const Scalar* const* input_data,
1140           const Dims<4>* const* input_dims, int inputs_count,
1141           Scalar* output_data, const Dims<4>& output_dims) {
1142   std::vector<RuntimeShape> input_shapes(inputs_count);
1143   std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
1144   for (int i = 0; i < inputs_count; ++i) {
1145     ShapeFromDims(*input_dims[i], &input_shapes[i]);
1146     input_shapes_indirect[i] = &input_shapes[i];
1147   }
1148   tflite::PackParams op_params;
1149   op_params.axis = 3 - dim;
1150   op_params.inputs_count = inputs_count;
1151 
1152   Pack(op_params, input_shapes_indirect.data(), input_data,
1153        DimsToShape(output_dims), output_data);
1154 }
1155 
1156 template <typename Scalar>
Unpack(int axis,const Scalar * input_data,const Dims<4> & input_dims,int dimensions,int outputs_count,Scalar * const * output_datas,const Dims<4> & output_dims)1157 void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
1158             int dimensions, int outputs_count, Scalar* const* output_datas,
1159             const Dims<4>& output_dims) {
1160   tflite::UnpackParams op_params;
1161   op_params.axis = 3 - axis;
1162   op_params.num_split = outputs_count;
1163 
1164   Unpack(op_params, DimsToShape(input_dims), input_data,
1165          DimsToShape(output_dims), output_datas);
1166 }
1167 
1168 template <typename Scalar>
Pack(int dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,const int32 * input_zeropoint,const float * input_scale,int inputs_count,Scalar * output_data,const Dims<4> & output_dims,const int32 output_zeropoint,const float output_scale)1169 void Pack(int dim, const Scalar* const* input_data,
1170           const Dims<4>* const* input_dims, const int32* input_zeropoint,
1171           const float* input_scale, int inputs_count, Scalar* output_data,
1172           const Dims<4>& output_dims, const int32 output_zeropoint,
1173           const float output_scale) {
1174   std::vector<RuntimeShape> input_shapes(inputs_count);
1175   std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
1176   for (int i = 0; i < inputs_count; ++i) {
1177     ShapeFromDims(*input_dims[i], &input_shapes[i]);
1178     input_shapes_indirect[i] = &input_shapes[i];
1179   }
1180   tflite::PackParams op_params;
1181   op_params.axis = 3 - dim;
1182   op_params.input_zeropoint = input_zeropoint;
1183   op_params.input_scale = input_scale;
1184   op_params.inputs_count = inputs_count;
1185   op_params.output_zeropoint = output_zeropoint;
1186   op_params.output_scale = output_scale;
1187 
1188   PackWithScaling(op_params, input_shapes_indirect.data(), input_data,
1189                   DimsToShape(output_dims), output_data);
1190 }
1191 
1192 template <FusedActivationFunctionType Ac>
L2Normalization(const float * input_data,const RuntimeShape & input_shape,float * output_data,const RuntimeShape & output_shape)1193 void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
1194                      float* output_data, const RuntimeShape& output_shape) {
1195   static_assert(Ac == FusedActivationFunctionType::kNone, "");
1196   tflite::L2NormalizationParams op_params;
1197   // No params need to be set for float.
1198 
1199   L2Normalization(op_params, input_shape, input_data, output_shape,
1200                   output_data);
1201 }
1202 
L2Normalization(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,uint8 * output_data,const RuntimeShape & output_shape)1203 inline void L2Normalization(const uint8* input_data,
1204                             const RuntimeShape& input_shape,
1205                             int32 input_zero_point, uint8* output_data,
1206                             const RuntimeShape& output_shape) {
1207   tflite::L2NormalizationParams op_params;
1208   op_params.input_zero_point = input_zero_point;
1209 
1210   L2Normalization(op_params, input_shape, input_data, output_shape,
1211                   output_data);
1212 }
1213 
1214 template <FusedActivationFunctionType Ac>
L2Normalization(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1215 void L2Normalization(const float* input_data, const Dims<4>& input_dims,
1216                      float* output_data, const Dims<4>& output_dims) {
1217   L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
1218                       DimsToShape(output_dims));
1219 }
1220 
L2Normalization(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,uint8 * output_data,const Dims<4> & output_dims)1221 inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
1222                             int32 input_zero_point, uint8* output_data,
1223                             const Dims<4>& output_dims) {
1224   L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
1225                   output_data, DimsToShape(output_dims));
1226 }
1227 
Relu(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1228 inline void Relu(const float* input_data, const Dims<4>& input_dims,
1229                  float* output_data, const Dims<4>& output_dims) {
1230   Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1231        output_data);
1232 }
1233 
Relu1(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1234 inline void Relu1(const float* input_data, const Dims<4>& input_dims,
1235                   float* output_data, const Dims<4>& output_dims) {
1236   Relu1(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1237         output_data);
1238 }
1239 
Relu6(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1240 inline void Relu6(const float* input_data, const Dims<4>& input_dims,
1241                   float* output_data, const Dims<4>& output_dims) {
1242   Relu6(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1243         output_data);
1244 }
1245 
ReluX(uint8 min_value,uint8 max_value,const uint8 * input_data,const RuntimeShape & input_shape,uint8 * output_data,const RuntimeShape & output_shape)1246 inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
1247                   const RuntimeShape& input_shape, uint8* output_data,
1248                   const RuntimeShape& output_shape) {
1249   tflite::ActivationParams params;
1250   params.quantized_activation_max = max_value;
1251   params.quantized_activation_min = min_value;
1252   ReluX(params, input_shape, input_data, output_shape, output_data);
1253 }
1254 
1255 template <FusedActivationFunctionType Ac>
Add(int left_shift,const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1256 inline void Add(int left_shift, const uint8* input1_data,
1257                 const Dims<4>& input1_dims, int32 input1_offset,
1258                 int32 input1_multiplier, int input1_shift,
1259                 const uint8* input2_data, const Dims<4>& input2_dims,
1260                 int32 input2_offset, int32 input2_multiplier, int input2_shift,
1261                 int32 output_offset, int32 output_multiplier, int output_shift,
1262                 int32 output_activation_min, int32 output_activation_max,
1263                 uint8* output_data, const Dims<4>& output_dims) {
1264   constexpr int kReverseShift = -1;
1265   static_assert(Ac == FusedActivationFunctionType::kNone ||
1266                     Ac == FusedActivationFunctionType::kRelu ||
1267                     Ac == FusedActivationFunctionType::kRelu6 ||
1268                     Ac == FusedActivationFunctionType::kRelu1,
1269                 "");
1270   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1271   if (Ac == FusedActivationFunctionType::kNone) {
1272     TFLITE_DCHECK_EQ(output_activation_min, 0);
1273     TFLITE_DCHECK_EQ(output_activation_max, 255);
1274   }
1275 
1276   tflite::ArithmeticParams op_params;
1277   op_params.left_shift = left_shift;
1278   op_params.input1_offset = input1_offset;
1279   op_params.input1_multiplier = input1_multiplier;
1280   op_params.input1_shift = kReverseShift * input1_shift;
1281   op_params.input2_offset = input2_offset;
1282   op_params.input2_multiplier = input2_multiplier;
1283   op_params.input2_shift = kReverseShift * input2_shift;
1284   op_params.output_offset = output_offset;
1285   op_params.output_multiplier = output_multiplier;
1286   op_params.output_shift = kReverseShift * output_shift;
1287   op_params.quantized_activation_min = output_activation_min;
1288   op_params.quantized_activation_max = output_activation_max;
1289   Add(op_params, DimsToShape(input1_dims), input1_data,
1290       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1291       output_data);
1292 }
1293 
1294 template <FusedActivationFunctionType Ac>
Add(const int32 * input1_data,const Dims<4> & input1_dims,const int32 * input2_data,const Dims<4> & input2_dims,int32 * output_data,const Dims<4> & output_dims)1295 void Add(const int32* input1_data, const Dims<4>& input1_dims,
1296          const int32* input2_data, const Dims<4>& input2_dims,
1297          int32* output_data, const Dims<4>& output_dims) {
1298   ruy::profiler::ScopeLabel label("Add/int32");
1299   TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
1300 
1301   tflite::ArithmeticParams op_params;
1302   op_params.quantized_activation_min = std::numeric_limits<int32>::min();
1303   op_params.quantized_activation_max = std::numeric_limits<int32>::max();
1304   Add(op_params, DimsToShape(input1_dims), input1_data,
1305       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1306       output_data);
1307 }
1308 
1309 template <FusedActivationFunctionType Ac>
BroadcastAdd(int left_shift,const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1310 inline void BroadcastAdd(int left_shift, const uint8* input1_data,
1311                          const Dims<4>& input1_dims, int32 input1_offset,
1312                          int32 input1_multiplier, int input1_shift,
1313                          const uint8* input2_data, const Dims<4>& input2_dims,
1314                          int32 input2_offset, int32 input2_multiplier,
1315                          int input2_shift, int32 output_offset,
1316                          int32 output_multiplier, int output_shift,
1317                          int32 output_activation_min,
1318                          int32 output_activation_max, uint8* output_data,
1319                          const Dims<4>& output_dims) {
1320   constexpr int kReverseShift = -1;
1321   static_assert(Ac == FusedActivationFunctionType::kNone ||
1322                     Ac == FusedActivationFunctionType::kRelu ||
1323                     Ac == FusedActivationFunctionType::kRelu6 ||
1324                     Ac == FusedActivationFunctionType::kRelu1,
1325                 "");
1326   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1327   if (Ac == FusedActivationFunctionType::kNone) {
1328     TFLITE_DCHECK_EQ(output_activation_min, 0);
1329     TFLITE_DCHECK_EQ(output_activation_max, 255);
1330   }
1331 
1332   tflite::ArithmeticParams op_params;
1333   op_params.left_shift = left_shift;
1334   op_params.input1_offset = input1_offset;
1335   op_params.input1_multiplier = input1_multiplier;
1336   op_params.input1_shift = kReverseShift * input1_shift;
1337   op_params.input2_offset = input2_offset;
1338   op_params.input2_multiplier = input2_multiplier;
1339   op_params.input2_shift = kReverseShift * input2_shift;
1340   op_params.output_offset = output_offset;
1341   op_params.output_multiplier = output_multiplier;
1342   op_params.output_shift = kReverseShift * output_shift;
1343   op_params.quantized_activation_min = output_activation_min;
1344   op_params.quantized_activation_max = output_activation_max;
1345   BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1346                      DimsToShape(input2_dims), input2_data,
1347                      DimsToShape(output_dims), output_data);
1348 }
1349 
1350 template <FusedActivationFunctionType Ac>
Add(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1351 void Add(const float* input1_data, const Dims<4>& input1_dims,
1352          const float* input2_data, const Dims<4>& input2_dims,
1353          float* output_data, const Dims<4>& output_dims) {
1354   float output_activation_min, output_activation_max;
1355   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1356 
1357   tflite::ArithmeticParams op_params;
1358   op_params.float_activation_min = output_activation_min;
1359   op_params.float_activation_max = output_activation_max;
1360   Add(op_params, DimsToShape(input1_dims), input1_data,
1361       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1362       output_data);
1363 }
1364 
1365 template <typename T>
BroadcastAdd(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1366 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
1367                   const T* input2_data, const Dims<4>& input2_dims,
1368                   T output_activation_min, T output_activation_max,
1369                   T* output_data, const Dims<4>& output_dims) {
1370   tflite::ArithmeticParams op_params;
1371   op_params.float_activation_min = output_activation_min;
1372   op_params.float_activation_max = output_activation_max;
1373   BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1374                      DimsToShape(input2_dims), input2_data,
1375                      DimsToShape(output_dims), output_data);
1376 }
1377 
// Legacy "fivefold" broadcasting Add for quantized uint8 data: the
// broadcast pattern is described by five repetition counts y0..y4, which
// are copied into op_params.broadcast_shape with reversed indexing.
template <FusedActivationFunctionType Ac>
inline void BroadcastAddFivefold(
    int y0, int y1, int y2, int y3, int y4, int left_shift,
    const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
    int32 input1_multiplier, int input1_shift, const uint8* input2_data,
    const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
    int input2_shift, int32 output_offset, int32 output_multiplier,
    int output_shift, int32 output_activation_min, int32 output_activation_max,
    uint8* output_data, const Dims<4>& output_dims) {
  // Legacy shifts are right-is-positive; the new convention is
  // left-is-positive, hence the negation below.
  constexpr int kReverseShift = -1;
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  // With no fused activation, the clamp must span the full uint8 range.
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  tflite::ArithmeticParams op_params;
  op_params.broadcast_category =
      tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
  op_params.left_shift = left_shift;
  op_params.input1_offset = input1_offset;
  op_params.input1_multiplier = input1_multiplier;
  op_params.input1_shift = kReverseShift * input1_shift;
  op_params.input2_offset = input2_offset;
  op_params.input2_multiplier = input2_multiplier;
  op_params.input2_shift = kReverseShift * input2_shift;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  op_params.output_shift = kReverseShift * output_shift;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  // y0..y4 map to broadcast_shape[4]..[0] — note the reversed indexing.
  op_params.broadcast_shape[4] = y0;
  op_params.broadcast_shape[3] = y1;
  op_params.broadcast_shape[2] = y2;
  op_params.broadcast_shape[1] = y3;
  op_params.broadcast_shape[0] = y4;
  BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
                       DimsToShape(input2_dims), input2_data,
                       DimsToShape(output_dims), output_data);
}
1422 
1423 // legacy, for compatibility with old checked-in code
1424 template <FusedActivationFunctionType Ac, typename T>
BroadcastAdd(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1425 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
1426                   const T* input2_data, const Dims<4>& input2_dims,
1427                   T* output_data, const Dims<4>& output_dims) {
1428   T output_activation_min, output_activation_max;
1429   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1430 
1431   BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
1432                output_activation_min, output_activation_max, output_data,
1433                output_dims);
1434 }
1435 
// Legacy Dims<4>-based elementwise Add for int16 data with fixed-point
// rescaling. The incoming shifts use the legacy sign convention and are
// multiplied by the file-level kReverseShift constant before being handed to
// the ArithmeticParams-based Add.
template <FusedActivationFunctionType Ac>
inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
                int input1_shift, const int16* input2_data,
                const Dims<4>& input2_dims, int input2_shift,
                int16 output_activation_min, int16 output_activation_max,
                int16* output_data, const Dims<4>& output_dims) {
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  if (Ac == FusedActivationFunctionType::kNone) {
    // Without a fused activation the clamp must cover the whole int16 range.
    TFLITE_DCHECK_EQ(output_activation_min, -32768);
    TFLITE_DCHECK_EQ(output_activation_max, 32767);
  }

  tflite::ArithmeticParams op_params;
  op_params.input1_shift = kReverseShift * input1_shift;
  op_params.input2_shift = kReverseShift * input2_shift;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  Add(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}
1462 
Sub(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1463 inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
1464                 const float* input2_data, const Dims<4>& input2_dims,
1465                 float* output_data, const Dims<4>& output_dims) {
1466   float output_activation_min, output_activation_max;
1467   GetActivationMinMax(FusedActivationFunctionType::kNone,
1468                       &output_activation_min, &output_activation_max);
1469   tflite::ArithmeticParams op_params;
1470   op_params.float_activation_min = output_activation_min;
1471   op_params.float_activation_max = output_activation_max;
1472   Sub(op_params, DimsToShape(input1_dims), input1_data,
1473       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1474       output_data);
1475 }
1476 
1477 template <typename T>
Sub(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1478 void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
1479          const Dims<4>& input2_dims, T* output_data,
1480          const Dims<4>& output_dims) {
1481   tflite::ArithmeticParams op_params;
1482   op_params.quantized_activation_min = std::numeric_limits<T>::min();
1483   op_params.quantized_activation_max = std::numeric_limits<T>::max();
1484   Sub(op_params, DimsToShape(input1_dims), input1_data,
1485       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1486       output_data);
1487 }
1488 
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1489 inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
1490                         int stride_width, int stride_height, int pad_width,
1491                         int pad_height, int kwidth, int kheight,
1492                         float output_activation_min,
1493                         float output_activation_max, float* output_data,
1494                         const Dims<4>& output_dims) {
1495   tflite::PoolParams params;
1496   params.stride_height = stride_height;
1497   params.stride_width = stride_width;
1498   params.filter_height = kheight;
1499   params.filter_width = kwidth;
1500   params.padding_values.height = pad_height;
1501   params.padding_values.width = pad_width;
1502   params.float_activation_min = output_activation_min;
1503   params.float_activation_max = output_activation_max;
1504   AveragePool(params, DimsToShape(input_dims), input_data,
1505               DimsToShape(output_dims), output_data);
1506 }
1507 
1508 // Transitional version that will be moved shortly to legacy_reference_ops, as
1509 // part of RuntimeShape revisions.
BroadcastMul4DSlow(const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1510 inline void BroadcastMul4DSlow(const uint8* input1_data,
1511                                const Dims<4>& input1_dims, int32 input1_offset,
1512                                const uint8* input2_data,
1513                                const Dims<4>& input2_dims, int32 input2_offset,
1514                                int32 output_offset, int32 output_multiplier,
1515                                int output_shift, int32 output_activation_min,
1516                                int32 output_activation_max, uint8* output_data,
1517                                const Dims<4>& output_dims) {
1518   tflite::ArithmeticParams op_params;
1519   SetActivationParams(output_activation_min, output_activation_max, &op_params);
1520   op_params.input1_offset = input1_offset;
1521   op_params.input2_offset = input2_offset;
1522   op_params.output_offset = output_offset;
1523   op_params.output_multiplier = output_multiplier;
1524   op_params.output_shift = output_shift;
1525 
1526   BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1527                      DimsToShape(input2_dims), input2_data,
1528                      DimsToShape(output_dims), output_data);
1529 }
1530 
BroadcastMul(const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1531 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
1532                          int32 input1_offset, const uint8* input2_data,
1533                          const Dims<4>& input2_dims, int32 input2_offset,
1534                          int32 output_offset, int32 output_multiplier,
1535                          int output_shift, int32 output_activation_min,
1536                          int32 output_activation_max, uint8* output_data,
1537                          const Dims<4>& output_dims) {
1538   BroadcastMul4DSlow(
1539       input1_data, input1_dims, input1_offset, input2_data, input2_dims,
1540       input2_offset, output_offset, output_multiplier,
1541       //
1542       kReverseShift * output_shift,
1543       //
1544       output_activation_min, output_activation_max, output_data, output_dims);
1545 }
1546 
// legacy, for compatibility with old checked-in code
//
// The fused-activation template parameter is not consulted here: the clamp
// range arrives explicitly, so this simply forwards every argument to the
// non-template BroadcastMul overload.
template <FusedActivationFunctionType Ac>
inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
                         int32 input1_offset, const uint8* input2_data,
                         const Dims<4>& input2_dims, int32 input2_offset,
                         int32 output_offset, int32 output_multiplier,
                         int output_shift, int32 output_activation_min,
                         int32 output_activation_max, uint8* output_data,
                         const Dims<4>& output_dims) {
  BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
               input2_dims, input2_offset, output_offset, output_multiplier,
               output_shift, output_activation_min, output_activation_max,
               output_data, output_dims);
}
1561 
1562 // legacy, for compatibility with old checked-in code
1563 template <FusedActivationFunctionType Ac>
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float * output_data,const Dims<4> & output_dims)1564 void AveragePool(const float* input_data, const Dims<4>& input_dims,
1565                  int stride_width, int stride_height, int pad_width,
1566                  int pad_height, int kwidth, int kheight, float* output_data,
1567                  const Dims<4>& output_dims) {
1568   float output_activation_min, output_activation_max;
1569   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1570 
1571   AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
1572               pad_height, kwidth, kheight, output_activation_min,
1573               output_activation_max, output_data, output_dims);
1574 }
1575 
1576 // legacy, for compatibility with old checked-in code
1577 template <FusedActivationFunctionType Ac>
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1578 void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
1579                  int pad_width, int pad_height, int filter_width,
1580                  int filter_height, float* output_data,
1581                  const Dims<4>& output_dims) {
1582   AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1583                   filter_width, filter_height, output_data, output_dims);
1584 }
1585 
AveragePool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1586 inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
1587                         int stride_width, int stride_height, int pad_width,
1588                         int pad_height, int filter_width, int filter_height,
1589                         int32 output_activation_min,
1590                         int32 output_activation_max, uint8* output_data,
1591                         const Dims<4>& output_dims) {
1592   tflite::PoolParams params;
1593   params.stride_height = stride_height;
1594   params.stride_width = stride_width;
1595   params.filter_height = filter_height;
1596   params.filter_width = filter_width;
1597   params.padding_values.height = pad_height;
1598   params.padding_values.width = pad_width;
1599   params.quantized_activation_min = output_activation_min;
1600   params.quantized_activation_max = output_activation_max;
1601   AveragePool(params, DimsToShape(input_dims), input_data,
1602               DimsToShape(output_dims), output_data);
1603 }
1604 
// legacy, for compatibility with old checked-in code
//
// Validates the clamp range implied by the fused-activation template
// parameter, then forwards to the non-template uint8 AveragePool.
template <FusedActivationFunctionType Ac>
void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
                 int stride_width, int stride_height, int pad_width,
                 int pad_height, int filter_width, int filter_height,
                 int32 output_activation_min, int32 output_activation_max,
                 uint8* output_data, const Dims<4>& output_dims) {
  // Only these four fused-activation kinds are supported by this shim.
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  if (Ac == FusedActivationFunctionType::kNone) {
    // No fused activation: the clamp must span the full uint8 range.
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
              pad_height, filter_width, filter_height, output_activation_min,
              output_activation_max, output_data, output_dims);
}
1625 
1626 // legacy, for compatibility with old checked-in code
1627 template <FusedActivationFunctionType Ac>
AveragePool(const uint8 * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1628 void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
1629                  int pad_width, int pad_height, int filter_width,
1630                  int filter_height, int32 output_activation_min,
1631                  int32 output_activation_max, uint8* output_data,
1632                  const Dims<4>& output_dims) {
1633   AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1634                   filter_width, filter_height, output_activation_min,
1635                   output_activation_max, output_data, output_dims);
1636 }
1637 
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1638 inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
1639                     int stride_width, int stride_height, int pad_width,
1640                     int pad_height, int kwidth, int kheight,
1641                     float output_activation_min, float output_activation_max,
1642                     float* output_data, const Dims<4>& output_dims) {
1643   tflite::PoolParams params;
1644   params.stride_height = stride_height;
1645   params.stride_width = stride_width;
1646   params.filter_height = kheight;
1647   params.filter_width = kwidth;
1648   params.padding_values.height = pad_height;
1649   params.padding_values.width = pad_width;
1650   params.float_activation_min = output_activation_min;
1651   params.float_activation_max = output_activation_max;
1652   MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1653           output_data);
1654 }
1655 
1656 // legacy, for compatibility with old checked-in code
1657 template <FusedActivationFunctionType Ac>
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float * output_data,const Dims<4> & output_dims)1658 void MaxPool(const float* input_data, const Dims<4>& input_dims,
1659              int stride_width, int stride_height, int pad_width, int pad_height,
1660              int kwidth, int kheight, float* output_data,
1661              const Dims<4>& output_dims) {
1662   float output_activation_min, output_activation_max;
1663   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1664   MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
1665           pad_height, kwidth, kheight, output_activation_min,
1666           output_activation_max, output_data, output_dims);
1667 }
1668 
1669 // legacy, for compatibility with old checked-in code
1670 template <FusedActivationFunctionType Ac>
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1671 void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
1672              int pad_width, int pad_height, int filter_width, int filter_height,
1673              float* output_data, const Dims<4>& output_dims) {
1674   MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1675               filter_width, filter_height, output_data, output_dims);
1676 }
1677 
MaxPool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1678 inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
1679                     int stride_width, int stride_height, int pad_width,
1680                     int pad_height, int filter_width, int filter_height,
1681                     int32 output_activation_min, int32 output_activation_max,
1682                     uint8* output_data, const Dims<4>& output_dims) {
1683   PoolParams params;
1684   params.stride_height = stride_height;
1685   params.stride_width = stride_width;
1686   params.filter_height = filter_height;
1687   params.filter_width = filter_width;
1688   params.padding_values.height = pad_height;
1689   params.padding_values.width = pad_width;
1690   params.quantized_activation_min = output_activation_min;
1691   params.quantized_activation_max = output_activation_max;
1692   MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1693           output_data);
1694 }
1695 
// legacy, for compatibility with old checked-in code
//
// Validates the clamp range implied by the fused-activation template
// parameter, then forwards to the non-template uint8 MaxPool.
template <FusedActivationFunctionType Ac>
void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
             int stride_width, int stride_height, int pad_width, int pad_height,
             int filter_width, int filter_height, int32 output_activation_min,
             int32 output_activation_max, uint8* output_data,
             const Dims<4>& output_dims) {
  // Only these four fused-activation kinds are supported by this shim.
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  if (Ac == FusedActivationFunctionType::kNone) {
    // No fused activation: the clamp must span the full uint8 range.
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
          pad_height, filter_width, filter_height, output_activation_min,
          output_activation_max, output_data, output_dims);
}
1716 
1717 // legacy, for compatibility with old checked-in code
1718 template <FusedActivationFunctionType Ac>
MaxPool(const uint8 * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1719 void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
1720              int pad_width, int pad_height, int filter_width, int filter_height,
1721              int32 output_activation_min, int32 output_activation_max,
1722              uint8* output_data, const Dims<4>& output_dims) {
1723   MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1724               filter_width, filter_height, output_activation_min,
1725               output_activation_max, output_data, output_dims);
1726 }
1727 
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1728 inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
1729                    int stride_width, int stride_height, int pad_width,
1730                    int pad_height, int filter_width, int filter_height,
1731                    float output_activation_min, float output_activation_max,
1732                    float* output_data, const Dims<4>& output_dims) {
1733   PoolParams params;
1734   params.stride_height = stride_height;
1735   params.stride_width = stride_width;
1736   params.filter_height = filter_height;
1737   params.filter_width = filter_width;
1738   params.padding_values.height = pad_height;
1739   params.padding_values.width = pad_width;
1740   params.float_activation_min = output_activation_min;
1741   params.float_activation_max = output_activation_max;
1742   L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1743          output_data);
1744 }
1745 
1746 // legacy, for compatibility with old checked-in code
1747 template <FusedActivationFunctionType Ac>
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1748 void L2Pool(const float* input_data, const Dims<4>& input_dims,
1749             int stride_width, int stride_height, int pad_width, int pad_height,
1750             int filter_width, int filter_height, float* output_data,
1751             const Dims<4>& output_dims) {
1752   float output_activation_min, output_activation_max;
1753   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1754   L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
1755          pad_height, filter_width, filter_height, output_activation_min,
1756          output_activation_max, output_data, output_dims);
1757 }
1758 
1759 // legacy, for compatibility with old checked-in code
1760 template <FusedActivationFunctionType Ac>
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1761 void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
1762             int pad_width, int pad_height, int filter_width, int filter_height,
1763             float* output_data, const Dims<4>& output_dims) {
1764   L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1765              filter_width, filter_height, output_data, output_dims);
1766 }
1767 
Softmax(const float * input_data,const Dims<4> & input_dims,float beta,float * output_data,const Dims<4> & output_dims)1768 inline void Softmax(const float* input_data, const Dims<4>& input_dims,
1769                     float beta, float* output_data,
1770                     const Dims<4>& output_dims) {
1771   Softmax(input_data, DimsToShape(input_dims), beta, output_data,
1772           DimsToShape(output_dims));
1773 }
1774 
Softmax(const uint8 * input_data,const Dims<4> & input_dims,int32 input_beta_multiplier,int32 input_beta_left_shift,int diff_min,uint8 * output_data,const Dims<4> & output_dims)1775 inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
1776                     int32 input_beta_multiplier, int32 input_beta_left_shift,
1777                     int diff_min, uint8* output_data,
1778                     const Dims<4>& output_dims) {
1779   Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier,
1780           input_beta_left_shift, diff_min, output_data,
1781           DimsToShape(output_dims));
1782 }
1783 
LogSoftmax(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1784 inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
1785                        float* output_data, const Dims<4>& output_dims) {
1786   LogSoftmax(input_data, DimsToShape(input_dims), output_data,
1787              DimsToShape(output_dims));
1788 }
1789 
LogSoftmax(const uint8 * input_data,const Dims<4> & input_dims,int32 input_multiplier,int32 input_left_shift,int32 reverse_scaling_divisor,int32 reverse_scaling_right_shift,int diff_min,uint8 * output_data,const Dims<4> & output_dims)1790 inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
1791                        int32 input_multiplier, int32 input_left_shift,
1792                        int32 reverse_scaling_divisor,
1793                        int32 reverse_scaling_right_shift, int diff_min,
1794                        uint8* output_data, const Dims<4>& output_dims) {
1795   LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier,
1796              input_left_shift, reverse_scaling_divisor,
1797              reverse_scaling_right_shift, diff_min, output_data,
1798              DimsToShape(output_dims));
1799 }
1800 
Logistic(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1801 inline void Logistic(const float* input_data, const Dims<4>& input_dims,
1802                      float* output_data, const Dims<4>& output_dims) {
1803   Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1804            output_data);
1805 }
1806 
Logistic(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const Dims<4> & output_dims)1807 inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
1808                      int32 input_zero_point, int32 input_range_radius,
1809                      int32 input_multiplier, int input_left_shift,
1810                      uint8* output_data, const Dims<4>& output_dims) {
1811   Logistic(input_data, DimsToShape(input_dims), input_zero_point,
1812            input_range_radius, input_multiplier, input_left_shift, output_data,
1813            DimsToShape(output_dims));
1814 }
1815 
Logistic(const int16 * input_data,const Dims<4> & input_dims,int16 * output_data,const Dims<4> & output_dims)1816 inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
1817                      int16* output_data, const Dims<4>& output_dims) {
1818   Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1819            output_data);
1820 }
1821 
Tanh(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1822 inline void Tanh(const float* input_data, const Dims<4>& input_dims,
1823                  float* output_data, const Dims<4>& output_dims) {
1824   Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1825        output_data);
1826 }
1827 
Tanh(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const Dims<4> & output_dims)1828 inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
1829                  int32 input_zero_point, int32 input_range_radius,
1830                  int32 input_multiplier, int input_left_shift,
1831                  uint8* output_data, const Dims<4>& output_dims) {
1832   Tanh(input_data, DimsToShape(input_dims), input_zero_point,
1833        input_range_radius, input_multiplier, input_left_shift, output_data,
1834        DimsToShape(output_dims));
1835 }
1836 
Tanh(const int16 * input_data,const Dims<4> & input_dims,int input_left_shift,int16 * output_data,const Dims<4> & output_dims)1837 inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
1838                  int input_left_shift, int16* output_data,
1839                  const Dims<4>& output_dims) {
1840   Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data,
1841        DimsToShape(output_dims));
1842 }
1843 
1844 template <typename T>
DepthToSpace(const T * input_data,const Dims<4> & input_dims,int block_size,T * output_data,const Dims<4> & output_dims)1845 inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
1846                          int block_size, T* output_data,
1847                          const Dims<4>& output_dims) {
1848   tflite::DepthToSpaceParams op_params;
1849   op_params.block_size = block_size;
1850 
1851   DepthToSpace(op_params, DimsToShape(input_dims), input_data,
1852                DimsToShape(output_dims), output_data);
1853 }
1854 
1855 template <typename T>
SpaceToDepth(const T * input_data,const Dims<4> & input_dims,int block_size,T * output_data,const Dims<4> & output_dims)1856 inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
1857                          int block_size, T* output_data,
1858                          const Dims<4>& output_dims) {
1859   tflite::SpaceToDepthParams op_params;
1860   op_params.block_size = block_size;
1861 
1862   SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
1863                DimsToShape(output_dims), output_data);
1864 }
1865 
1866 template <typename T>
Mul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1867 inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
1868                 const T* input2_data, const Dims<4>& input2_dims,
1869                 T output_activation_min, T output_activation_max,
1870                 T* output_data, const Dims<4>& output_dims) {
1871   tflite::ArithmeticParams op_params;
1872   SetActivationParams(output_activation_min, output_activation_max, &op_params);
1873 
1874   Mul(op_params, DimsToShape(input1_dims), input1_data,
1875       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1876       output_data);
1877 }
1878 
1879 // legacy, for compatibility with old checked-in code
1880 template <FusedActivationFunctionType Ac>
Mul(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1881 void Mul(const float* input1_data, const Dims<4>& input1_dims,
1882          const float* input2_data, const Dims<4>& input2_dims,
1883          float* output_data, const Dims<4>& output_dims) {
1884   float output_activation_min, output_activation_max;
1885   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1886 
1887   tflite::ArithmeticParams op_params;
1888   SetActivationParams(output_activation_min, output_activation_max, &op_params);
1889 
1890   Mul(op_params, DimsToShape(input1_dims), input1_data,
1891       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1892       output_data);
1893 }
1894 
1895 template <typename T>
BroadcastMul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1896 void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
1897                   const T* input2_data, const Dims<4>& input2_dims,
1898                   T output_activation_min, T output_activation_max,
1899                   T* output_data, const Dims<4>& output_dims) {
1900   tflite::ArithmeticParams op_params;
1901   SetActivationParams(output_activation_min, output_activation_max, &op_params);
1902 
1903   BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1904                      DimsToShape(input2_dims), input2_data,
1905                      DimsToShape(output_dims), output_data);
1906 }
1907 
1908 // legacy, for compatibility with old checked-in code
1909 template <FusedActivationFunctionType Ac, typename T>
BroadcastMul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1910 void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
1911                   const T* input2_data, const Dims<4>& input2_dims,
1912                   T* output_data, const Dims<4>& output_dims) {
1913   T output_activation_min, output_activation_max;
1914   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1915 
1916   tflite::ArithmeticParams op_params;
1917   SetActivationParams(output_activation_min, output_activation_max, &op_params);
1918 
1919   BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1920                      DimsToShape(input2_dims), input2_data,
1921                      DimsToShape(output_dims), output_data);
1922 }
1923 
Mul(const int16 * input1_data,const Dims<4> & input1_dims,const int16 * input2_data,const Dims<4> & input2_dims,int16 * output_data,const Dims<4> & output_dims)1924 inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
1925                 const int16* input2_data, const Dims<4>& input2_dims,
1926                 int16* output_data, const Dims<4>& output_dims) {
1927   tflite::ArithmeticParams op_params;
1928   // No params in this version.
1929 
1930   Mul(op_params, DimsToShape(input1_dims), input1_data,
1931       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1932       output_data);
1933 }
1934 
Mul(const int16 * input1_data,const Dims<4> & input1_dims,const int16 * input2_data,const Dims<4> & input2_dims,int32 output_offset,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1935 inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
1936                 const int16* input2_data, const Dims<4>& input2_dims,
1937                 int32 output_offset, int32 output_activation_min,
1938                 int32 output_activation_max, uint8* output_data,
1939                 const Dims<4>& output_dims) {
1940   tflite::ArithmeticParams op_params;
1941   op_params.quantized_activation_min = output_activation_min;
1942   op_params.quantized_activation_max = output_activation_max;
1943   op_params.output_offset = output_offset;
1944 
1945   Mul(op_params, DimsToShape(input1_dims), input1_data,
1946       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1947       output_data);
1948 }
1949 
LocalResponseNormalization(const float * input_data,const Dims<4> & input_dims,int range,float bias,float alpha,float beta,float * output_data,const Dims<4> & output_dims)1950 inline void LocalResponseNormalization(const float* input_data,
1951                                        const Dims<4>& input_dims, int range,
1952                                        float bias, float alpha, float beta,
1953                                        float* output_data,
1954                                        const Dims<4>& output_dims) {
1955   tflite::LocalResponseNormalizationParams op_params;
1956   op_params.range = range;
1957   op_params.bias = bias;
1958   op_params.alpha = alpha;
1959   op_params.beta = beta;
1960 
1961   LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
1962                              DimsToShape(output_dims), output_data);
1963 }
1964 
1965 template <typename SrcT, typename DstT>
Cast(const SrcT * input_data,const Dims<4> & input_dims,DstT * output_data,const Dims<4> & output_dims)1966 void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
1967           const Dims<4>& output_dims) {
1968   Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1969        output_data);
1970 }
1971 
Floor(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1972 inline void Floor(const float* input_data, const Dims<4>& input_dims,
1973                   float* output_data, const Dims<4>& output_dims) {
1974   Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1975         output_data);
1976 }
1977 
1978 template <typename T>
ResizeBilinear(const T * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,T * output_data,const Dims<4> & output_dims,bool align_corners)1979 inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
1980                            const int32* output_size_data,
1981                            const Dims<4>& output_size_dims, T* output_data,
1982                            const Dims<4>& output_dims, bool align_corners) {
1983   tflite::ResizeBilinearParams op_params;
1984   op_params.align_corners = align_corners;
1985   op_params.half_pixel_centers = false;
1986   ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
1987                  DimsToShape(output_size_dims), output_size_data,
1988                  DimsToShape(output_dims), output_data);
1989 }
1990 
1991 // legacy, for compatibility with old checked-in code
ResizeBilinear(const float * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,float * output_data,const Dims<4> & output_dims)1992 inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
1993                            const int32* output_size_data,
1994                            const Dims<4>& output_size_dims, float* output_data,
1995                            const Dims<4>& output_dims) {
1996   ResizeBilinear<float>(input_data, input_dims, output_size_data,
1997                         output_size_dims, output_data, output_dims,
1998                         /*align_corners=*/false);
1999 }
2000 
ResizeBilinear(const uint8 * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,uint8 * output_data,const Dims<4> & output_dims)2001 inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
2002                            const int32* output_size_data,
2003                            const Dims<4>& output_size_dims, uint8* output_data,
2004                            const Dims<4>& output_dims) {
2005   ResizeBilinear<uint8>(input_data, input_dims, output_size_data,
2006                         output_size_dims, output_data, output_dims,
2007                         /*align_corners=*/false);
2008 }
2009 
2010 template <typename T>
SpaceToBatchND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * paddings_data,const Dims<4> & paddings_dims,T * output_data,const Dims<4> & output_dims,const int32_t pad_value)2011 inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
2012                            const int32* block_shape_data,
2013                            const Dims<4>& block_shape_dims,
2014                            const int32* paddings_data,
2015                            const Dims<4>& paddings_dims, T* output_data,
2016                            const Dims<4>& output_dims,
2017                            const int32_t pad_value) {
2018   tflite::SpaceToBatchParams op_params;
2019   op_params.output_offset = pad_value;
2020 
2021   SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
2022                  DimsToShape(block_shape_dims), block_shape_data,
2023                  DimsToShape(paddings_dims), paddings_data,
2024                  DimsToShape(output_dims), output_data);
2025 }
2026 
2027 template <typename T>
SpaceToBatchND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * paddings_data,const Dims<4> & paddings_dims,T * output_data,const Dims<4> & output_dims)2028 inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
2029                            const int32* block_shape_data,
2030                            const Dims<4>& block_shape_dims,
2031                            const int32* paddings_data,
2032                            const Dims<4>& paddings_dims, T* output_data,
2033                            const Dims<4>& output_dims) {
2034   tflite::SpaceToBatchParams op_params;
2035   op_params.output_offset = 0;
2036 
2037   SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
2038                  DimsToShape(block_shape_dims), block_shape_data,
2039                  DimsToShape(paddings_dims), paddings_data,
2040                  DimsToShape(output_dims), output_data);
2041 }
2042 
2043 template <typename T>
BatchToSpaceND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * crops_data,const Dims<4> & crops_dims,T * output_data,const Dims<4> & output_dims)2044 inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
2045                            const int32* block_shape_data,
2046                            const Dims<4>& block_shape_dims,
2047                            const int32* crops_data, const Dims<4>& crops_dims,
2048                            T* output_data, const Dims<4>& output_dims) {
2049   BatchToSpaceND(DimsToShape(input_dims), input_data,
2050                  DimsToShape(block_shape_dims), block_shape_data,
2051                  DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
2052                  output_data);
2053 }
2054 
2055 // Legacy signature, function covered both Pad and PadV2.
2056 template <typename T>
PadV2(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims,const T pad_value)2057 inline void PadV2(const T* input_data, const Dims<4>& input_dims,
2058                   const std::vector<int>& left_paddings,
2059                   const std::vector<int>& right_paddings, T* output_data,
2060                   const Dims<4>& output_dims, const T pad_value) {
2061   TFLITE_DCHECK_EQ(left_paddings.size(), 4);
2062   TFLITE_DCHECK_EQ(right_paddings.size(), 4);
2063   tflite::PadParams op_params;
2064   op_params.left_padding_count = 4;
2065   op_params.right_padding_count = 4;
2066   for (int i = 0; i < 4; ++i) {
2067     op_params.left_padding[i] = left_paddings[3 - i];
2068     op_params.right_padding[i] = right_paddings[3 - i];
2069   }
2070   // SetFloatOrInt(pad_value, &op_params.pad_value);
2071   const T pad_value_copy = pad_value;
2072 
2073   Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
2074       DimsToShape(output_dims), output_data);
2075 }
2076 
2077 // Old Pad that calls legacy PadV2.
2078 template <typename T>
Pad(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims,const int32_t pad_value)2079 inline void Pad(const T* input_data, const Dims<4>& input_dims,
2080                 const std::vector<int>& left_paddings,
2081                 const std::vector<int>& right_paddings, T* output_data,
2082                 const Dims<4>& output_dims, const int32_t pad_value) {
2083   const T converted_pad_value = static_cast<T>(pad_value);
2084   PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
2085            output_dims, converted_pad_value);
2086 }
2087 
2088 // Old Pad that only padded with 0.
2089 template <typename T>
Pad(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims)2090 inline void Pad(const T* input_data, const Dims<4>& input_dims,
2091                 const std::vector<int>& left_paddings,
2092                 const std::vector<int>& right_paddings, T* output_data,
2093                 const Dims<4>& output_dims) {
2094   const T pad_value = static_cast<T>(0);
2095   PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
2096            output_dims, pad_value);
2097 }
2098 
2099 template <typename T>
TensorFlowMinimum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,T * output_data,const Dims<4> & output_dims)2100 void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
2101                        const T* input2_data, T* output_data,
2102                        const Dims<4>& output_dims) {
2103   Minimum(DimsToShape(input1_dims), input1_data, input2_data,
2104           DimsToShape(output_dims), output_data);
2105 }
2106 
2107 template <typename T>
TensorFlowMaximum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,T * output_data,const Dims<4> & output_dims)2108 void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
2109                        const T* input2_data, T* output_data,
2110                        const Dims<4>& output_dims) {
2111   Maximum(DimsToShape(input1_dims), input1_data, input2_data,
2112           DimsToShape(output_dims), output_data);
2113 }
2114 
2115 template <typename T, typename Op>
TensorFlowMaximumMinimum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims,Op op)2116 void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
2117                               const T* input2_data, const Dims<4>& input2_dims,
2118                               T* output_data, const Dims<4>& output_dims,
2119                               Op op) {
2120   MaximumMinimumBroadcastSlow(DimsToShape(input1_dims), input1_data,
2121                               DimsToShape(input2_dims), input2_data,
2122                               DimsToShape(output_dims), output_data, op);
2123 }
2124 
// Legacy Dims<4>-based ArgMax over the innermost dimension only.
// T1: input element type. T2: output index type. T3: axis element type.
template <typename T1, typename T2, typename T3>
void ArgMax(const T3* axis, const T1* input_data,
            const tflite::Dims<4>& input_dims, T2* output_data,
            const tflite::Dims<4>& output_dims) {
  // Assumes the input always has 4 dimensions, and therefore,
  // output always has three dimensions.
  auto output_shape = RuntimeShape(
      {output_dims.sizes[2], output_dims.sizes[1], output_dims.sizes[0]});
  // Another way to interpret this is that output_dims.sizes[3] is always 1:
  // the outermost size is dropped when building the 3-D output shape above,
  // so the flat sizes below can only match when that dimension is 1.
  TFLITE_DCHECK_EQ(output_shape.FlatSize(),
                   DimsToShape(output_dims).FlatSize());
  // Legacy path only supported this.
  TFLITE_DCHECK_EQ(axis[0], 3);
  // std::greater selects the maximum (ArgMax rather than ArgMin).
  ArgMinMax(DimsToShape(input_dims), input_data, axis, output_shape,
            output_data, std::greater<T1>());
}
2141 
2142 template <typename T1, typename T2, typename T3, typename Cmp>
ArgMinMax(const T3 * axis,const T1 * input_data,const Dims<4> & input_dims,T2 * output_data,const Dims<4> & output_dims,const Cmp & cmp)2143 void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
2144                T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
2145   ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
2146             output_data, cmp);
2147 }
2148 
2149 template <typename T>
Pow(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)2150 inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
2151                 const T* input2_data, const Dims<4>& input2_dims,
2152                 T* output_data, const Dims<4>& output_dims) {
2153   Pow(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
2154       input2_data, DimsToShape(output_dims), output_data);
2155 }
2156 
2157 template <typename T>
BroadcastPow(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)2158 inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
2159                          const T* input2_data, const Dims<4>& input2_dims,
2160                          T* output_data, const Dims<4>& output_dims) {
2161   BroadcastPow4DSlow(DimsToShape(input1_dims), input1_data,
2162                      DimsToShape(input2_dims), input2_data,
2163                      DimsToShape(output_dims), output_data);
2164 }
2165 
2166 // R: Result type. T1: Input 1 type. T2: Input 2 type.
2167 template <typename R, typename T1, typename T2>
BroadcastBinaryFunction(const T1 * input1_data,const Dims<4> & input1_dims,const T2 * input2_data,const Dims<4> & input2_dims,R * output_data,const Dims<4> & output_dims,R (* func)(T1,T2))2168 inline void BroadcastBinaryFunction(const T1* input1_data,
2169                                     const Dims<4>& input1_dims,
2170                                     const T2* input2_data,
2171                                     const Dims<4>& input2_dims, R* output_data,
2172                                     const Dims<4>& output_dims,
2173                                     R (*func)(T1, T2)) {
2174   BroadcastBinaryFunction(DimsToShape(input1_dims), input1_data,
2175                           DimsToShape(input2_dims), input2_data,
2176                           DimsToShape(output_dims), output_data, func);
2177 }
2178 
2179 // R: Result type. T1: Input 1 type. T2: Input 2 type.
2180 template <typename R, typename T1, typename T2>
BinaryFunction(const T1 * input1_data,const Dims<4> & input1_dims,const T2 * input2_data,const Dims<4> & input2_dims,R * output_data,const Dims<4> & output_dims,R (* func)(T1,T2))2181 inline void BinaryFunction(const T1* input1_data, const Dims<4>& input1_dims,
2182                            const T2* input2_data, const Dims<4>& input2_dims,
2183                            R* output_data, const Dims<4>& output_dims,
2184                            R (*func)(T1, T2)) {
2185   BinaryFunction(DimsToShape(input1_dims), input1_data,
2186                  DimsToShape(input2_dims), input2_data,
2187                  DimsToShape(output_dims), output_data, func);
2188 }
2189 
2190 template <typename T>
Slice(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & begin,const std::vector<int> & size,T * output_data,const Dims<4> & output_dims)2191 inline void Slice(const T* input_data, const Dims<4>& input_dims,
2192                   const std::vector<int>& begin, const std::vector<int>& size,
2193                   T* output_data, const Dims<4>& output_dims) {
2194   tflite::SliceParams op_params;
2195   op_params.begin_count = 4;
2196   op_params.size_count = 4;
2197   for (int i = 0; i < 4; ++i) {
2198     op_params.begin[i] = begin[3 - i];
2199     op_params.size[i] = size[3 - i];
2200   }
2201 
2202   Slice(op_params, DimsToShape(input_dims), input_data,
2203         DimsToShape(output_dims), output_data);
2204 }
2205 
2206 }  // namespace reference_ops
2207 }  // namespace tflite
2208 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
2209