1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
17
18 #include <stdint.h>
19 #include <sys/types.h>
20
21 #include "public/gemmlowp.h"
22 #include "tensorflow/lite/kernels/internal/common.h"
23 #include "tensorflow/lite/kernels/internal/legacy_types.h"
24 #include "tensorflow/lite/kernels/internal/reference/conv.h"
25 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
26 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
27 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
28 #include "tensorflow/lite/kernels/internal/reference/tanh.h"
29 #include "tensorflow/lite/kernels/internal/types.h"
30
31 namespace tflite {
32
33 namespace reference_ops {
34
// Multiplier used to convert a legacy depthwise-conv output shift (where a
// positive value meant "shift right") into the current DepthwiseParams
// convention, where a positive shift means "shift left".
static constexpr int kDepthwiseReverseShift = -1;
36
ShapeFromDims(const tflite::Dims<4> & dims,RuntimeShape * shape)37 inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
38 shape->BuildFrom(
39 {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
40 }
41
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int depth_multiplier,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)42 inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
43 const float* filter_data, const Dims<4>& filter_dims,
44 const float* bias_data, const Dims<4>& bias_dims,
45 int stride_width, int stride_height,
46 int dilation_width_factor, int dilation_height_factor,
47 int pad_width, int pad_height, int depth_multiplier,
48 float output_activation_min,
49 float output_activation_max, float* output_data,
50 const Dims<4>& output_dims) {
51 tflite::DepthwiseParams op_params;
52 // Padding type is ignored, but still set.
53 op_params.padding_type = PaddingType::kSame;
54 op_params.padding_values.width = pad_width;
55 op_params.padding_values.height = pad_height;
56 op_params.stride_width = stride_width;
57 op_params.stride_height = stride_height;
58 op_params.dilation_width_factor = dilation_width_factor;
59 op_params.dilation_height_factor = dilation_height_factor;
60 op_params.depth_multiplier = depth_multiplier;
61 op_params.float_activation_min = output_activation_min;
62 op_params.float_activation_max = output_activation_max;
63
64 DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
65 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
66 bias_data, DimsToShape(output_dims), output_data);
67 }
68
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)69 inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
70 const float* filter_data, const Dims<4>& filter_dims,
71 const float* bias_data, const Dims<4>& bias_dims,
72 int stride_width, int stride_height, int pad_width,
73 int pad_height, int depth_multiplier,
74 float output_activation_min,
75 float output_activation_max, float* output_data,
76 const Dims<4>& output_dims) {
77 DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
78 bias_dims, stride_width, stride_height, 1, 1, pad_width,
79 pad_height, depth_multiplier, output_activation_min,
80 output_activation_max, output_data, output_dims);
81 }
82
83 // Legacy, for compatibility with old checked-in code.
84 template <FusedActivationFunctionType Ac>
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,float * output_data,const Dims<4> & output_dims)85 void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
86 const float* filter_data, const Dims<4>& filter_dims,
87 const float* bias_data, const Dims<4>& bias_dims,
88 int stride_width, int stride_height, int pad_width,
89 int pad_height, int depth_multiplier, float* output_data,
90 const Dims<4>& output_dims) {
91 float output_activation_min, output_activation_max;
92 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
93 DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
94 bias_dims, stride_width, stride_height, pad_width, pad_height,
95 depth_multiplier, output_activation_min, output_activation_max,
96 output_data, output_dims);
97 }
98
// Legacy, for compatibility with old checked-in code. Single-stride variant:
// the one `stride` argument is applied to both spatial dimensions.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                   const float* filter_data, const Dims<4>& filter_dims,
                   const float* bias_data, const Dims<4>& bias_dims, int stride,
                   int pad_width, int pad_height, int depth_multiplier,
                   float* output_data, const Dims<4>& output_dims) {
  // Forward with stride used for both width and height.
  DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
                    bias_dims, stride, stride, pad_width, pad_height,
                    depth_multiplier, output_data, output_dims);
}
110
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)111 inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
112 int32 input_offset, const uint8* filter_data,
113 const Dims<4>& filter_dims, int32 filter_offset,
114 const int32* bias_data, const Dims<4>& bias_dims,
115 int stride_width, int stride_height,
116 int dilation_width_factor, int dilation_height_factor,
117 int pad_width, int pad_height, int depth_multiplier,
118 int32 output_offset, int32 output_multiplier,
119 int output_shift, int32 output_activation_min,
120 int32 output_activation_max, uint8* output_data,
121 const Dims<4>& output_dims) {
122 tflite::DepthwiseParams op_params;
123 // Padding type is ignored, but still set.
124 op_params.padding_type = PaddingType::kSame;
125 op_params.padding_values.width = pad_width;
126 op_params.padding_values.height = pad_height;
127 op_params.stride_width = stride_width;
128 op_params.stride_height = stride_height;
129 op_params.dilation_width_factor = dilation_width_factor;
130 op_params.dilation_height_factor = dilation_height_factor;
131 op_params.depth_multiplier = depth_multiplier;
132 op_params.quantized_activation_min = output_activation_min;
133 op_params.quantized_activation_max = output_activation_max;
134 op_params.input_offset = input_offset;
135 op_params.weights_offset = filter_offset;
136 op_params.output_offset = output_offset;
137 op_params.output_multiplier = output_multiplier;
138 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
139 op_params.output_shift = kDepthwiseReverseShift * output_shift;
140
141 DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
142 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
143 bias_data, DimsToShape(output_dims), output_data);
144 }
145
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)146 inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
147 int32 input_offset, const uint8* filter_data,
148 const Dims<4>& filter_dims, int32 filter_offset,
149 const int32* bias_data, const Dims<4>& bias_dims,
150 int stride_width, int stride_height, int pad_width,
151 int pad_height, int depth_multiplier,
152 int32 output_offset, int32 output_multiplier,
153 int output_shift, int32 output_activation_min,
154 int32 output_activation_max, uint8* output_data,
155 const Dims<4>& output_dims) {
156 DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
157 filter_offset, bias_data, bias_dims, stride_width,
158 stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
159 output_offset, output_multiplier, output_shift,
160 output_activation_min, output_activation_max, output_data,
161 output_dims);
162 }
163
// Legacy, for compatibility with old checked-in code. The activation template
// parameter is used only to sanity-check the clamp bounds: with no fused
// activation, the full uint8 range [0, 255] is expected.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                   int32 input_offset, const uint8* filter_data,
                   const Dims<4>& filter_dims, int32 filter_offset,
                   const int32* bias_data, const Dims<4>& bias_dims,
                   int stride_width, int stride_height, int pad_width,
                   int pad_height, int depth_multiplier, int32 output_offset,
                   int32 output_multiplier, int output_shift,
                   int32 output_activation_min, int32 output_activation_max,
                   uint8* output_data, const Dims<4>& output_dims) {
  if (Ac == FusedActivationFunctionType::kNone) {
    // Debug-only consistency check; no-op in release builds.
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
                filter_offset, bias_data, bias_dims, stride_width,
                stride_height, pad_width, pad_height, depth_multiplier,
                output_offset, output_multiplier, output_shift,
                output_activation_min, output_activation_max, output_data,
                output_dims);
}
186
// Legacy, for compatibility with old checked-in code. Single-stride quantized
// variant: the one `stride` argument is applied to both spatial dimensions.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                   int32 input_offset, const uint8* filter_data,
                   const Dims<4>& filter_dims, int32 filter_offset,
                   const int32* bias_data, const Dims<4>& bias_dims, int stride,
                   int pad_width, int pad_height, int depth_multiplier,
                   int32 output_offset, int32 output_multiplier,
                   int output_shift, int32 output_activation_min,
                   int32 output_activation_max, uint8* output_data,
                   const Dims<4>& output_dims) {
  // Forward with stride used for both width and height.
  DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
                    filter_dims, filter_offset, bias_data, bias_dims, stride,
                    stride, pad_width, pad_height, depth_multiplier,
                    output_offset, output_multiplier, output_shift,
                    output_activation_min, output_activation_max, output_data,
                    output_dims);
}
205
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)206 inline void Conv(const float* input_data, const Dims<4>& input_dims,
207 const float* filter_data, const Dims<4>& filter_dims,
208 const float* bias_data, const Dims<4>& bias_dims,
209 int stride_width, int stride_height, int dilation_width_factor,
210 int dilation_height_factor, int pad_width, int pad_height,
211 float output_activation_min, float output_activation_max,
212 float* output_data, const Dims<4>& output_dims,
213 float* im2col_data, const Dims<4>& im2col_dims) {
214 tflite::ConvParams op_params;
215 // Padding type is ignored, but still set.
216 op_params.padding_type = PaddingType::kSame;
217 op_params.padding_values.width = pad_width;
218 op_params.padding_values.height = pad_height;
219 op_params.stride_width = stride_width;
220 op_params.stride_height = stride_height;
221 op_params.dilation_width_factor = dilation_width_factor;
222 op_params.dilation_height_factor = dilation_height_factor;
223 op_params.float_activation_min = output_activation_min;
224 op_params.float_activation_max = output_activation_max;
225
226 Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
227 filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
228 output_data, DimsToShape(im2col_dims), im2col_data);
229 }
230
231 template <FusedActivationFunctionType Ac>
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)232 void Conv(const float* input_data, const Dims<4>& input_dims,
233 const float* filter_data, const Dims<4>& filter_dims,
234 const float* bias_data, const Dims<4>& bias_dims, int stride_width,
235 int stride_height, int dilation_width_factor,
236 int dilation_height_factor, int pad_width, int pad_height,
237 float* output_data, const Dims<4>& output_dims, float* im2col_data,
238 const Dims<4>& im2col_dims) {
239 float output_activation_min, output_activation_max;
240 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
241 Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
242 stride_width, stride_height, dilation_width_factor,
243 dilation_height_factor, pad_width, pad_height, output_activation_min,
244 output_activation_max, output_data, output_dims, im2col_data,
245 im2col_dims);
246 }
247
248 // legacy, for compatibility with old checked-in code
249 template <FusedActivationFunctionType Ac>
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)250 void Conv(const float* input_data, const Dims<4>& input_dims,
251 const float* filter_data, const Dims<4>& filter_dims,
252 const float* bias_data, const Dims<4>& bias_dims, int stride_width,
253 int stride_height, int pad_width, int pad_height, float* output_data,
254 const Dims<4>& output_dims, float* im2col_data,
255 const Dims<4>& im2col_dims) {
256 float output_activation_min, output_activation_max;
257 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
258 Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
259 stride_width, stride_height, 1, 1, pad_width, pad_height,
260 output_activation_min, output_activation_max, output_data, output_dims,
261 im2col_data, im2col_dims);
262 }
263
// legacy, for compatibility with old checked-in code
// Single-stride variant: `stride` is applied to both spatial dimensions, and
// dilation is fixed at 1 in each dimension.
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
          const float* filter_data, const Dims<4>& filter_dims,
          const float* bias_data, const Dims<4>& bias_dims, int stride,
          int pad_width, int pad_height, float* output_data,
          const Dims<4>& output_dims, float* im2col_data,
          const Dims<4>& im2col_dims) {
  Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
           bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
           output_dims, im2col_data, im2col_dims);
}
276
Conv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,uint8 * im2col_data,const Dims<4> & im2col_dims,gemmlowp::GemmContext * gemmlowp_context)277 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
278 int32 input_offset, const uint8* filter_data,
279 const Dims<4>& filter_dims, int32 filter_offset,
280 const int32* bias_data, const Dims<4>& bias_dims,
281 int stride_width, int stride_height, int dilation_width_factor,
282 int dilation_height_factor, int pad_width, int pad_height,
283 int32 output_offset, int32 output_multiplier, int output_shift,
284 int32 output_activation_min, int32 output_activation_max,
285 uint8* output_data, const Dims<4>& output_dims,
286 uint8* im2col_data, const Dims<4>& im2col_dims,
287 gemmlowp::GemmContext* gemmlowp_context) {
288 tflite::ConvParams op_params;
289 // Padding type is ignored, but still set.
290 op_params.padding_type = PaddingType::kSame;
291 op_params.padding_values.width = pad_width;
292 op_params.padding_values.height = pad_height;
293 op_params.stride_width = stride_width;
294 op_params.stride_height = stride_height;
295 op_params.dilation_width_factor = dilation_width_factor;
296 op_params.dilation_height_factor = dilation_height_factor;
297 op_params.input_offset = input_offset;
298 op_params.weights_offset = filter_offset;
299 op_params.output_offset = output_offset;
300 op_params.output_multiplier = output_multiplier;
301 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
302 op_params.output_shift = kReverseShift * output_shift;
303 op_params.quantized_activation_min = output_activation_min;
304 op_params.quantized_activation_max = output_activation_max;
305
306 Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
307 filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
308 output_data, DimsToShape(im2col_dims), im2col_data, gemmlowp_context);
309 }
310
Conv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,uint8 * im2col_data,const Dims<4> & im2col_dims,gemmlowp::GemmContext * gemmlowp_context)311 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
312 int32 input_offset, const uint8* filter_data,
313 const Dims<4>& filter_dims, int32 filter_offset,
314 const int32* bias_data, const Dims<4>& bias_dims,
315 int stride_width, int stride_height, int pad_width,
316 int pad_height, int32 output_offset, int32 output_multiplier,
317 int output_shift, int32 output_activation_min,
318 int32 output_activation_max, uint8* output_data,
319 const Dims<4>& output_dims, uint8* im2col_data,
320 const Dims<4>& im2col_dims,
321 gemmlowp::GemmContext* gemmlowp_context) {
322 Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
323 filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
324 pad_width, pad_height, output_offset, output_multiplier, output_shift,
325 output_activation_min, output_activation_max, output_data, output_dims,
326 im2col_data, im2col_dims, gemmlowp_context);
327 }
328
// legacy, for compatibility with old checked-in code
// Restricts the activation template parameter to the supported set at compile
// time, and (debug-only) checks that kNone callers use the full uint8 clamp
// range [0, 255] before forwarding.
template <FusedActivationFunctionType Ac>
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
                 int32 input_offset, const uint8* filter_data,
                 const Dims<4>& filter_dims, int32 filter_offset,
                 const int32* bias_data, const Dims<4>& bias_dims,
                 int stride_width, int stride_height, int pad_width,
                 int pad_height, int32 output_offset, int32 output_multiplier,
                 int output_shift, int32 output_activation_min,
                 int32 output_activation_max, uint8* output_data,
                 const Dims<4>& output_dims, uint8* im2col_data,
                 const Dims<4>& im2col_dims,
                 gemmlowp::GemmContext* gemmlowp_context) {
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  if (Ac == FusedActivationFunctionType::kNone) {
    // Debug-only consistency check; no-op in release builds.
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
       filter_offset, bias_data, bias_dims, stride_width, stride_height,
       pad_width, pad_height, output_offset, output_multiplier, output_shift,
       output_activation_min, output_activation_max, output_data, output_dims,
       im2col_data, im2col_dims, gemmlowp_context);
}
357
// legacy, for compatibility with old checked-in code
// Single-stride quantized variant: `stride` is applied to both spatial
// dimensions.
template <FusedActivationFunctionType Ac>
void Conv(const uint8* input_data, const Dims<4>& input_dims,
          int32 input_offset, const uint8* filter_data,
          const Dims<4>& filter_dims, int32 filter_offset,
          const int32* bias_data, const Dims<4>& bias_dims, int stride,
          int pad_width, int pad_height, int32 output_offset,
          int32 output_multiplier, int output_shift,
          int32 output_activation_min, int32 output_activation_max,
          uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
          const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) {
  // Forward with stride used for both width and height.
  Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims,
           filter_offset, bias_data, bias_dims, stride, stride, pad_width,
           pad_height, output_offset, output_multiplier, output_shift,
           output_activation_min, output_activation_max, output_data,
           output_dims, im2col_data, im2col_dims, gemmlowp_context);
}
375
TransposeConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,int stride_width,int stride_height,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)376 inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
377 const float* filter_data, const Dims<4>& filter_dims,
378 int stride_width, int stride_height, int pad_width,
379 int pad_height, float* output_data,
380 const Dims<4>& output_dims, float* im2col_data,
381 const Dims<4>& im2col_dims) {
382 tflite::ConvParams op_params;
383 // Padding type is ignored, but still set.
384 op_params.padding_type = PaddingType::kSame;
385 op_params.padding_values.width = pad_width;
386 op_params.padding_values.height = pad_height;
387 op_params.stride_width = stride_width;
388 op_params.stride_height = stride_height;
389
390 TransposeConv(op_params, DimsToShape(input_dims), input_data,
391 DimsToShape(filter_dims), filter_data,
392 /*bias_shape*/ RuntimeShape(), /*bias*/ nullptr,
393 DimsToShape(output_dims), output_data, DimsToShape(im2col_dims),
394 im2col_data);
395 }
396
TransposeConv(const ConvParams & params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & filter_shape,const float * filter_data,const RuntimeShape & output_shape,float * output_data,const RuntimeShape & im2col_shape,float * im2col_data)397 inline void TransposeConv(
398 const ConvParams& params, const RuntimeShape& input_shape,
399 const float* input_data, const RuntimeShape& filter_shape,
400 const float* filter_data, const RuntimeShape& output_shape,
401 float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
402 TransposeConv(params, input_shape, input_data, filter_shape, filter_data,
403 /*bias_shape*/ RuntimeShape(), /*bias*/ nullptr, output_shape,
404 output_data, im2col_shape, im2col_data);
405 }
406
FullyConnected(const float * input_data,const Dims<4> & input_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)407 inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
408 const float* weights_data,
409 const Dims<4>& weights_dims, const float* bias_data,
410 const Dims<4>& bias_dims,
411 float output_activation_min,
412 float output_activation_max, float* output_data,
413 const Dims<4>& output_dims) {
414 tflite::FullyConnectedParams op_params;
415 op_params.float_activation_min = output_activation_min;
416 op_params.float_activation_max = output_activation_max;
417
418 FullyConnected(op_params, DimsToShape(input_dims), input_data,
419 DimsToShape(weights_dims), weights_data,
420 DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
421 output_data);
422 }
423
424 // legacy, for compatibility with old checked-in code
425 template <FusedActivationFunctionType Ac>
FullyConnected(const float * input_data,const Dims<4> & input_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,float * output_data,const Dims<4> & output_dims)426 void FullyConnected(const float* input_data, const Dims<4>& input_dims,
427 const float* weights_data, const Dims<4>& weights_dims,
428 const float* bias_data, const Dims<4>& bias_dims,
429 float* output_data, const Dims<4>& output_dims) {
430 float output_activation_min, output_activation_max;
431 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
432 FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
433 bias_dims, output_activation_min, output_activation_max,
434 output_data, output_dims);
435 }
436
// Shim matching the optimized-path signature: the gemmlowp context parameter
// is unused by the reference implementation and is dropped on forwarding.
inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, gemmlowp::GemmContext*) {
  FullyConnected(params, input_shape, input_data, filter_shape, filter_data,
                 bias_shape, bias_data, output_shape, output_data);
}
446
// Shim for the uint8-in / int16-out variant: the gemmlowp context parameter
// is unused by the reference implementation and is dropped on forwarding.
inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    int16* output_data, gemmlowp::GemmContext*) {
  FullyConnected(params, input_shape, input_data, filter_shape, filter_data,
                 bias_shape, bias_data, output_shape, output_data);
}
456
FullyConnected(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,gemmlowp::GemmContext * gemmlowp_context)457 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
458 int32 input_offset, const uint8* filter_data,
459 const Dims<4>& filter_dims, int32 filter_offset,
460 const int32* bias_data, const Dims<4>& bias_dims,
461 int32 output_offset, int32 output_multiplier,
462 int output_shift, int32 output_activation_min,
463 int32 output_activation_max, uint8* output_data,
464 const Dims<4>& output_dims,
465 gemmlowp::GemmContext* gemmlowp_context) {
466 tflite::FullyConnectedParams op_params;
467 op_params.input_offset = input_offset;
468 op_params.weights_offset = filter_offset;
469 op_params.output_offset = output_offset;
470 op_params.output_multiplier = output_multiplier;
471 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
472 op_params.output_shift = kReverseShift * output_shift;
473 op_params.quantized_activation_min = output_activation_min;
474 op_params.quantized_activation_max = output_activation_max;
475
476 FullyConnected(op_params, DimsToShape(input_dims), input_data,
477 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
478 bias_data, DimsToShape(output_dims), output_data,
479 gemmlowp_context);
480 }
481
FullyConnected(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,int16 * output_data,const Dims<4> & output_dims,gemmlowp::GemmContext * gemmlowp_context)482 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
483 int32 input_offset, const uint8* filter_data,
484 const Dims<4>& filter_dims, int32 filter_offset,
485 const int32* bias_data, const Dims<4>& bias_dims,
486 int32 output_offset, int32 output_multiplier,
487 int output_shift, int32 output_activation_min,
488 int32 output_activation_max, int16* output_data,
489 const Dims<4>& output_dims,
490 gemmlowp::GemmContext* gemmlowp_context) {
491 tflite::FullyConnectedParams op_params;
492 op_params.input_offset = input_offset;
493 op_params.weights_offset = filter_offset;
494 op_params.output_offset = output_offset;
495 op_params.output_multiplier = output_multiplier;
496 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
497 op_params.output_shift = kReverseShift * output_shift;
498 op_params.quantized_activation_min = output_activation_min;
499 op_params.quantized_activation_max = output_activation_max;
500
501 FullyConnected(op_params, DimsToShape(input_dims), input_data,
502 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
503 bias_data, DimsToShape(output_dims), output_data,
504 gemmlowp_context);
505 }
506
ShuffledFullyConnected(const FullyConnectedParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & weights_shape,const uint8 * shuffled_weights_data,const RuntimeShape & bias_shape,const int32 * bias_data,const RuntimeShape & output_shape,int16 * output_data,uint8 * shuffled_input_workspace_data,gemmlowp::GemmContext *)507 inline void ShuffledFullyConnected(
508 const FullyConnectedParams& params, const RuntimeShape& input_shape,
509 const uint8* input_data, const RuntimeShape& weights_shape,
510 const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
511 const int32* bias_data, const RuntimeShape& output_shape,
512 int16* output_data, uint8* shuffled_input_workspace_data,
513 gemmlowp::GemmContext*) {
514 ShuffledFullyConnected(params, input_shape, input_data, weights_shape,
515 shuffled_weights_data, bias_shape, bias_data,
516 output_shape, output_data,
517 shuffled_input_workspace_data);
518 }
519
ShuffledFullyConnected(const uint8 * input_data,const Dims<4> & input_dims,const uint8 * shuffled_weights_data,const Dims<4> & weights_dims,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,int16 * output_data,const Dims<4> & output_dims,uint8 * shuffled_input_workspace_data,gemmlowp::GemmContext * gemmlowp_context)520 inline void ShuffledFullyConnected(
521 const uint8* input_data, const Dims<4>& input_dims,
522 const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
523 const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
524 int output_shift, int32 output_activation_min, int32 output_activation_max,
525 int16* output_data, const Dims<4>& output_dims,
526 uint8* shuffled_input_workspace_data,
527 gemmlowp::GemmContext* gemmlowp_context) {
528 tflite::FullyConnectedParams op_params;
529 op_params.output_multiplier = output_multiplier;
530 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
531 op_params.output_shift = kReverseShift * output_shift;
532 op_params.quantized_activation_min = output_activation_min;
533 op_params.quantized_activation_max = output_activation_max;
534
535 ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
536 DimsToShape(weights_dims), shuffled_weights_data,
537 DimsToShape(bias_dims), bias_data,
538 DimsToShape(output_dims), output_data,
539 shuffled_input_workspace_data, gemmlowp_context);
540 }
541
// legacy, for compatibility with old checked-in code
// Template variant keyed on the fused-activation type. Restricts Ac to the
// supported fused kinds at compile time, debug-checks that a kNone
// activation uses the full uint8 range [0, 255], then forwards every
// argument unchanged to the non-template legacy overload above.
template <FusedActivationFunctionType Ac>
void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
                    int32 input_offset, const uint8* filter_data,
                    const Dims<4>& filter_dims, int32 filter_offset,
                    const int32* bias_data, const Dims<4>& bias_dims,
                    int32 output_offset, int32 output_multiplier,
                    int output_shift, int32 output_activation_min,
                    int32 output_activation_max, uint8* output_data,
                    const Dims<4>& output_dims,
                    gemmlowp::GemmContext* gemmlowp_context) {
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  if (Ac == FusedActivationFunctionType::kNone) {
    // No fused activation: the quantized range must be the identity range.
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
                 filter_offset, bias_data, bias_dims, output_offset,
                 output_multiplier, output_shift, output_activation_min,
                 output_activation_max, output_data, output_dims,
                 gemmlowp_context);
}
568
LstmCell(const float * input_data,const Dims<4> & input_dims,const float * prev_activ_data,const Dims<4> & prev_activ_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,const float * prev_state_data,const Dims<4> & prev_state_dims,float * output_state_data,const Dims<4> & output_state_dims,float * output_activ_data,const Dims<4> & output_activ_dims,float * concat_temp_data,const Dims<4> & concat_temp_dims,float * activ_temp_data,const Dims<4> & activ_temp_dims)569 inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
570 const float* prev_activ_data,
571 const Dims<4>& prev_activ_dims, const float* weights_data,
572 const Dims<4>& weights_dims, const float* bias_data,
573 const Dims<4>& bias_dims, const float* prev_state_data,
574 const Dims<4>& prev_state_dims, float* output_state_data,
575 const Dims<4>& output_state_dims, float* output_activ_data,
576 const Dims<4>& output_activ_dims, float* concat_temp_data,
577 const Dims<4>& concat_temp_dims, float* activ_temp_data,
578 const Dims<4>& activ_temp_dims) {
579 tflite::LstmCellParams op_params;
580 // Float LSTM cell does not need parameters to be set: leave untouched.
581
582 LstmCell(op_params, DimsToShape(input_dims), input_data,
583 DimsToShape(prev_activ_dims), prev_activ_data,
584 DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
585 bias_data, DimsToShape(prev_state_dims), prev_state_data,
586 DimsToShape(output_state_dims), output_state_data,
587 DimsToShape(output_activ_dims), output_activ_data,
588 DimsToShape(concat_temp_dims), concat_temp_data,
589 DimsToShape(activ_temp_dims), activ_temp_data);
590 }
591
592 template <int StateIntegerBits>
LstmCell(const uint8 * input_data_uint8,const Dims<4> & input_dims,const uint8 * prev_activ_data_uint8,const Dims<4> & prev_activ_dims,const uint8 * weights_data_uint8,const Dims<4> & weights_dims,const int32 * bias_data_int32,const Dims<4> & bias_dims,const int16 * prev_state_data_int16,const Dims<4> & prev_state_dims,int16 * output_state_data_int16,const Dims<4> & output_state_dims,uint8 * output_activ_data_uint8,const Dims<4> & output_activ_dims,uint8 * concat_temp_data_uint8,const Dims<4> & concat_temp_dims,int16 * activ_temp_data_int16,const Dims<4> & activ_temp_dims,int32 weights_zero_point,int32 accum_multiplier,int accum_shift,gemmlowp::GemmContext * gemmlowp_context)593 void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
594 const uint8* prev_activ_data_uint8,
595 const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
596 const Dims<4>& weights_dims, const int32* bias_data_int32,
597 const Dims<4>& bias_dims, const int16* prev_state_data_int16,
598 const Dims<4>& prev_state_dims, int16* output_state_data_int16,
599 const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
600 const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
601 const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
602 const Dims<4>& activ_temp_dims, int32 weights_zero_point,
603 int32 accum_multiplier, int accum_shift,
604 gemmlowp::GemmContext* gemmlowp_context) {
605 tflite::LstmCellParams op_params;
606 op_params.weights_zero_point = weights_zero_point;
607 op_params.accum_multiplier = accum_multiplier;
608 op_params.accum_shift = accum_shift;
609
610 LstmCell<StateIntegerBits>(
611 op_params, DimsToShape(input_dims), input_data_uint8,
612 DimsToShape(prev_activ_dims), prev_activ_data_uint8,
613 DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
614 bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
615 DimsToShape(output_state_dims), output_state_data_int16,
616 DimsToShape(output_activ_dims), output_activ_data_uint8,
617 DimsToShape(concat_temp_dims), concat_temp_data_uint8,
618 DimsToShape(activ_temp_dims), activ_temp_data_int16, gemmlowp_context);
619 }
620
621 template <typename T>
BroadcastDiv(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)622 void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
623 const T* input2_data, const Dims<4>& input2_dims,
624 T output_activation_min, T output_activation_max,
625 T* output_data, const Dims<4>& output_dims) {
626 tflite::ArithmeticParams op_params;
627 SetActivationParams(output_activation_min, output_activation_max, &op_params);
628
629 BroadcastDivSlow(op_params, DimsToShape(input1_dims), input1_data,
630 DimsToShape(input2_dims), input2_data,
631 DimsToShape(output_dims), output_data);
632 }
633
634 template <typename T>
Div(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)635 inline void Div(const T* input1_data, const Dims<4>& input1_dims,
636 const T* input2_data, const Dims<4>& input2_dims,
637 T output_activation_min, T output_activation_max,
638 T* output_data, const Dims<4>& output_dims) {
639 tflite::ArithmeticParams op_params;
640 SetActivationParams(output_activation_min, output_activation_max, &op_params);
641
642 Div(op_params, DimsToShape(input1_dims), input1_data,
643 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
644 output_data);
645 }
646
647 template <FusedActivationFunctionType Ac, typename Scalar>
Concatenation(int concat_dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)648 inline void Concatenation(int concat_dim, const Scalar* const* input_data,
649 const Dims<4>* const* input_dims, int inputs_count,
650 Scalar* output_data, const Dims<4>& output_dims) {
651 // For now we don't have a model with a Concatenation with fused activation.
652 TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
653
654 std::vector<RuntimeShape> input_shapes(inputs_count);
655 std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
656 for (int i = 0; i < inputs_count; ++i) {
657 ShapeFromDims(*input_dims[i], &input_shapes[i]);
658 input_shapes_indirect[i] = &input_shapes[i];
659 }
660 tflite::ConcatenationParams op_params;
661 op_params.axis = 3 - concat_dim;
662 op_params.inputs_count = inputs_count;
663
664 Concatenation(op_params, input_shapes_indirect.data(), input_data,
665 DimsToShape(output_dims), output_data);
666 }
667
Concatenation(int concat_dim,const uint8 * const * input_data,const Dims<4> * const * input_dims,const int32 * input_zeropoint,const float * input_scale,int inputs_count,uint8 * output_data,const Dims<4> & output_dims,const int32 output_zeropoint,const float output_scale)668 inline void Concatenation(int concat_dim, const uint8* const* input_data,
669 const Dims<4>* const* input_dims,
670 const int32* input_zeropoint,
671 const float* input_scale, int inputs_count,
672 uint8* output_data, const Dims<4>& output_dims,
673 const int32 output_zeropoint,
674 const float output_scale) {
675 std::vector<RuntimeShape> input_shapes(inputs_count);
676 std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
677 for (int i = 0; i < inputs_count; ++i) {
678 ShapeFromDims(*input_dims[i], &input_shapes[i]);
679 input_shapes_indirect[i] = &input_shapes[i];
680 }
681 tflite::ConcatenationParams op_params;
682 op_params.axis = 3 - concat_dim;
683 op_params.input_zeropoint = input_zeropoint;
684 op_params.input_scale = input_scale;
685 op_params.inputs_count = inputs_count;
686 op_params.output_zeropoint = output_zeropoint;
687 op_params.output_scale = output_scale;
688
689 ConcatenationWithScaling(op_params, input_shapes_indirect.data(), input_data,
690 DimsToShape(output_dims), output_data);
691 }
692
693 template <FusedActivationFunctionType Ac, typename Scalar>
DepthConcatenation(const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)694 void DepthConcatenation(const Scalar* const* input_data,
695 const Dims<4>* const* input_dims, int inputs_count,
696 Scalar* output_data, const Dims<4>& output_dims) {
697 // For now we don't have a model with a Concatenation with fused activation.
698 TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
699
700 std::vector<RuntimeShape> input_shapes(inputs_count);
701 std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
702 for (int i = 0; i < inputs_count; ++i) {
703 ShapeFromDims(*input_dims[i], &input_shapes[i]);
704 input_shapes_indirect[i] = &input_shapes[i];
705 }
706 tflite::ConcatenationParams op_params;
707 op_params.inputs_count = inputs_count;
708
709 DepthConcatenation(op_params, input_shapes_indirect.data(), input_data,
710 DimsToShape(output_dims), output_data);
711 }
712
713 template <typename Scalar>
TensorFlowSplit(const Scalar * input_data,const Dims<4> & input_dims,int axis,int outputs_count,Scalar * const * output_data,const Dims<4> * const * output_dims)714 void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
715 int axis, int outputs_count, Scalar* const* output_data,
716 const Dims<4>* const* output_dims) {
717 std::vector<RuntimeShape> output_shapes(outputs_count);
718 std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count);
719 for (int i = 0; i < outputs_count; ++i) {
720 ShapeFromDims(*output_dims[i], &output_shapes[i]);
721 output_shapes_indirect[i] = &output_shapes[i];
722 }
723 tflite::SplitParams op_params;
724 op_params.axis = 3 - axis;
725 op_params.num_split = outputs_count;
726
727 Split(op_params, DimsToShape(input_dims), input_data,
728 output_shapes_indirect.data(), output_data);
729 }
730
731 template <FusedActivationFunctionType Ac, typename Scalar>
TensorFlowSplit(const Scalar * input_data,const Dims<4> & input_dims,int outputs_count,Scalar * const * output_data,const Dims<4> * const * output_dims)732 void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
733 int outputs_count, Scalar* const* output_data,
734 const Dims<4>* const* output_dims) {
735 TFLITE_DCHECK_GE(outputs_count, 1);
736 for (int i = 0; i < outputs_count; i++) {
737 /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
738 /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
739 /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
740 }
741 // For now we don't have a model with a Split with fused activation.
742 TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
743
744 TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count,
745 output_data, output_dims);
746 }
747
Softmax(const float * input_data,const RuntimeShape & input_shape,float beta,float * output_data,const RuntimeShape & output_shape)748 inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
749 float beta, float* output_data,
750 const RuntimeShape& output_shape) {
751 SoftmaxParams params;
752 params.beta = beta;
753 Softmax(params, input_shape, input_data, output_shape, output_data);
754 }
755
Softmax(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_beta_multiplier,int32 input_beta_left_shift,int diff_min,uint8 * output_data,const RuntimeShape & output_shape)756 inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
757 int32 input_beta_multiplier, int32 input_beta_left_shift,
758 int diff_min, uint8* output_data,
759 const RuntimeShape& output_shape) {
760 SoftmaxParams params;
761 params.input_multiplier = input_beta_multiplier;
762 params.input_left_shift = input_beta_left_shift;
763 params.diff_min = diff_min;
764 Softmax(params, input_shape, input_data, output_shape, output_data);
765 }
766
LogSoftmax(const float * input_data,const RuntimeShape & input_shape,float * output_data,const RuntimeShape & output_shape)767 inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
768 float* output_data, const RuntimeShape& output_shape) {
769 SoftmaxParams params;
770 // No params currently used for float LogSoftmax.
771 LogSoftmax(params, input_shape, input_data, output_shape, output_data);
772 }
773
LogSoftmax(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_multiplier,int32 input_left_shift,int32 reverse_scaling_divisor,int32 reverse_scaling_right_shift,int diff_min,uint8 * output_data,const RuntimeShape & output_shape)774 inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
775 int32 input_multiplier, int32 input_left_shift,
776 int32 reverse_scaling_divisor,
777 int32 reverse_scaling_right_shift, int diff_min,
778 uint8* output_data, const RuntimeShape& output_shape) {
779 SoftmaxParams params;
780 params.input_multiplier = input_multiplier;
781 params.input_left_shift = input_left_shift;
782 params.reverse_scaling_divisor = reverse_scaling_divisor;
783 params.reverse_scaling_right_shift = reverse_scaling_right_shift;
784 params.diff_min = diff_min;
785 LogSoftmax(params, input_shape, input_data, output_shape, output_data);
786 }
787
Logistic(const LogisticParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & output_shape,uint8 * output_data)788 inline void Logistic(const LogisticParams& params,
789 const RuntimeShape& input_shape, const uint8* input_data,
790 const RuntimeShape& output_shape, uint8* output_data) {
791 const int32 input_zero_point = params.input_zero_point;
792 const int32 input_range_radius = params.input_range_radius;
793 const int32 input_multiplier = params.input_multiplier;
794 const int input_left_shift = params.input_left_shift;
795 const int flat_size = MatchingFlatSize(input_shape, output_shape);
796
797 for (int i = 0; i < flat_size; i++) {
798 const uint8 input_val_u8 = input_data[i];
799 const int32 input_val_centered =
800 static_cast<int32>(input_val_u8) - input_zero_point;
801 uint8 output_val;
802 if (input_val_centered <= -input_range_radius) {
803 output_val = 0;
804 } else if (input_val_centered >= input_range_radius) {
805 output_val = 255;
806 } else {
807 const int32 input_val_rescaled =
808 MultiplyByQuantizedMultiplierGreaterThanOne(
809 input_val_centered, input_multiplier, input_left_shift);
810 using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
811 using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
812 const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
813 const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
814 // Convert from Q0.31 to Q23.8.
815 using gemmlowp::RoundingDivideByPOT;
816 int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23);
817 if (output_val_s32 == 256) {
818 output_val_s32 = 255;
819 }
820 // Reinterpret as U0.8.
821 TFLITE_DCHECK_GE(output_val_s32, 0);
822 TFLITE_DCHECK_LE(output_val_s32, 255);
823 output_val = static_cast<uint8>(output_val_s32);
824 }
825 output_data[i] = output_val;
826 }
827 }
828
Logistic(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const RuntimeShape & output_shape)829 inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
830 int32 input_zero_point, int32 input_range_radius,
831 int32 input_multiplier, int input_left_shift,
832 uint8* output_data, const RuntimeShape& output_shape) {
833 LogisticParams params;
834 params.input_zero_point = input_zero_point;
835 params.input_range_radius = input_range_radius;
836 params.input_multiplier = input_multiplier;
837 params.input_left_shift = input_left_shift;
838 Logistic(params, input_shape, input_data, output_shape, output_data);
839 }
840
Logistic(const RuntimeShape & input_shape,const int16 * input_data,const RuntimeShape & output_shape,int16 * output_data)841 inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
842 const RuntimeShape& output_shape, int16* output_data) {
843 LogisticParams params;
844 // No params currently needed by int16 Logistic.
845 Logistic(params, input_shape, input_data, output_shape, output_data);
846 }
847
Tanh(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const RuntimeShape & output_shape)848 inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
849 int32 input_zero_point, int32 input_range_radius,
850 int32 input_multiplier, int input_left_shift,
851 uint8* output_data, const RuntimeShape& output_shape) {
852 TanhParams params;
853 params.input_zero_point = input_zero_point;
854 params.input_range_radius = input_range_radius;
855 params.input_multiplier = input_multiplier;
856 params.input_left_shift = input_left_shift;
857 Tanh(params, input_shape, input_data, output_shape, output_data);
858 }
859
Tanh(const int16 * input_data,const RuntimeShape & input_shape,int input_left_shift,int16 * output_data,const RuntimeShape & output_shape)860 inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
861 int input_left_shift, int16* output_data,
862 const RuntimeShape& output_shape) {
863 TanhParams params;
864 params.input_left_shift = input_left_shift;
865 Tanh(params, input_shape, input_data, output_shape, output_data);
866 }
867
Dequantize(const uint8 * input_data,const Dims<4> & input_dims,int32 zero_point,double scale,float * output_data,const Dims<4> & output_dims)868 inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
869 int32 zero_point, double scale, float* output_data,
870 const Dims<4>& output_dims) {
871 tflite::DequantizationParams op_params;
872 op_params.zero_point = zero_point;
873 op_params.scale = scale;
874
875 Dequantize(op_params, DimsToShape(input_dims), input_data,
876 DimsToShape(output_dims), output_data);
877 }
878
FakeQuant(const float * input_data,const Dims<4> & input_dims,float rmin,float rmax,int num_bits,float * output_data,const Dims<4> & output_dims)879 inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
880 float rmin, float rmax, int num_bits, float* output_data,
881 const Dims<4>& output_dims) {
882 tflite::FakeQuantParams op_params;
883 op_params.num_bits = num_bits;
884 op_params.minmax.min = rmin;
885 op_params.minmax.max = rmax;
886
887 FakeQuant(op_params, DimsToShape(input_dims), input_data,
888 DimsToShape(output_dims), output_data);
889 }
890
891 template <typename T>
Gather(const T * input_data,const Dims<4> & input_dims,int input_rank,const int32 * coords_data,const Dims<4> & coords_dims,T * output_data,const Dims<4> & output_dims)892 inline void Gather(const T* input_data, const Dims<4>& input_dims,
893 int input_rank, const int32* coords_data,
894 const Dims<4>& coords_dims, T* output_data,
895 const Dims<4>& output_dims) {
896 tflite::GatherParams op_params;
897 op_params.axis = 4 - input_rank;
898
899 Gather(op_params, DimsToShape(input_dims), input_data,
900 DimsToShape(coords_dims), coords_data, DimsToShape(output_dims),
901 output_data);
902 }
903
LegacyReverseBits32(uint32 n)904 inline uint32 LegacyReverseBits32(uint32 n) {
905 n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
906 n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
907 n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
908 return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
909 ((n & 0xFF000000) >> 24));
910 }
911
StridedSliceReverseIndices(tflite::StridedSliceParams * p)912 inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
913 TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
914 TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
915
916 std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
917 std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
918 std::reverse(p->strides, p->strides + p->strides_count);
919
920 p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
921 (32 - p->start_indices_count);
922 p->ellipsis_mask =
923 LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
924 (32 - p->start_indices_count);
925 p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
926 (32 - p->start_indices_count);
927 p->new_axis_mask =
928 LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
929 (32 - p->start_indices_count);
930 p->shrink_axis_mask =
931 LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
932 (32 - p->start_indices_count);
933 }
934
935 template <typename T>
StridedSlice(const T * input_data,const Dims<4> & input_dims,int begin_mask,int end_mask,int shrink_axis_mask,const std::vector<int> & start_indices,const std::vector<int> & stop_indices,const std::vector<int> & strides,T * output_data,const Dims<4> & output_dims)936 inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
937 int begin_mask, int end_mask, int shrink_axis_mask,
938 const std::vector<int>& start_indices,
939 const std::vector<int>& stop_indices,
940 const std::vector<int>& strides, T* output_data,
941 const Dims<4>& output_dims) {
942 TFLITE_DCHECK_EQ(start_indices.size(), 4);
943 auto op_params = strided_slice::BuildStridedSliceParams(
944 begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
945 strides);
946 StridedSliceReverseIndices(&op_params);
947
948 StridedSlice(op_params, DimsToShape(input_dims), input_data,
949 DimsToShape(output_dims), output_data);
950 }
951
952 template <typename T>
Mean(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & reduction_indices,T * output_data,const Dims<4> & output_dims)953 inline void Mean(const T* input_data, const Dims<4>& input_dims,
954 const std::vector<int>& reduction_indices, T* output_data,
955 const Dims<4>& output_dims) {
956 tflite::MeanParams op_params;
957 op_params.axis_count = reduction_indices.size();
958 for (int i = 0; i < op_params.axis_count; ++i) {
959 op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i];
960 }
961
962 Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
963 output_data);
964 }
965
966 template <typename T>
Transpose(const T * input,const Dims<4> & input_dims,T * output,const Dims<4> & output_dims,const int * permuted_axes)967 void Transpose(const T* input, const Dims<4>& input_dims, T* output,
968 const Dims<4>& output_dims, const int* permuted_axes) {
969 TransposeParams params;
970 params.perm_count = 4;
971 for (int i = 0; i < 4; ++i) {
972 params.perm[i] = 3 - permuted_axes[3 - i];
973 }
974 Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
975 output);
976 }
977
978 template <typename T, ComparisonFn<T> F>
Comparison(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,bool * output_data,const Dims<4> & output_dims)979 inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
980 const T* input2_data, const Dims<4>& input2_dims,
981 bool* output_data, const Dims<4>& output_dims) {
982 ComparisonParams op_params;
983 // No parameters needed.
984 ComparisonImpl<T, F>(op_params, DimsToShape(input1_dims), input1_data,
985 DimsToShape(input2_dims), input2_data,
986 DimsToShape(output_dims), output_data);
987 }
988
989 template <typename T, ComparisonFn<int32> F>
Comparison(int left_shift,const T * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const T * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,bool * output_data,const Dims<4> & output_dims)990 inline void Comparison(int left_shift, const T* input1_data,
991 const Dims<4>& input1_dims, int32 input1_offset,
992 int32 input1_multiplier, int input1_shift,
993 const T* input2_data, const Dims<4>& input2_dims,
994 int32 input2_offset, int32 input2_multiplier,
995 int input2_shift, bool* output_data,
996 const Dims<4>& output_dims) {
997 tflite::ComparisonParams op_params;
998 op_params.left_shift = left_shift;
999 op_params.input1_offset = input1_offset;
1000 op_params.input1_multiplier = input1_multiplier;
1001 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1002 op_params.input1_shift = kReverseShift * input1_shift;
1003 op_params.input2_offset = input2_offset;
1004 op_params.input2_multiplier = input2_multiplier;
1005 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1006 op_params.input2_shift = kReverseShift * input2_shift;
1007
1008 ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
1009 DimsToShape(input2_dims), input2_data,
1010 DimsToShape(output_dims), output_data);
1011 }
1012
1013 template <typename T, ComparisonFn<T> F>
BroadcastComparison(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,bool * output_data,const Dims<4> & output_dims)1014 inline void BroadcastComparison(const T* input1_data,
1015 const Dims<4>& input1_dims,
1016 const T* input2_data,
1017 const Dims<4>& input2_dims, bool* output_data,
1018 const Dims<4>& output_dims) {
1019 ComparisonParams op_params;
1020 // No parameters needed.
1021 BroadcastComparison4DSlowImpl<T, F>(op_params, DimsToShape(input1_dims),
1022 input1_data, DimsToShape(input2_dims),
1023 input2_data, DimsToShape(output_dims),
1024 output_data);
1025 }
1026
1027 template <typename T, ComparisonFn<int32> F>
BroadcastComparison(int left_shift,const T * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const T * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,bool * output_data,const Dims<4> & output_dims)1028 inline void BroadcastComparison(int left_shift, const T* input1_data,
1029 const Dims<4>& input1_dims, int32 input1_offset,
1030 int32 input1_multiplier, int input1_shift,
1031 const T* input2_data,
1032 const Dims<4>& input2_dims, int32 input2_offset,
1033 int32 input2_multiplier, int input2_shift,
1034 bool* output_data, const Dims<4>& output_dims) {
1035 ComparisonParams op_params;
1036
1037 op_params.left_shift = left_shift;
1038 op_params.input1_offset = input1_offset;
1039 op_params.input1_multiplier = input1_multiplier;
1040 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1041 op_params.input1_shift = kReverseShift * input1_shift;
1042 op_params.input2_offset = input2_offset;
1043 op_params.input2_multiplier = input2_multiplier;
1044 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1045 op_params.input2_shift = kReverseShift * input2_shift;
1046
1047 BroadcastComparison4DSlowWithScaling<T, F>(
1048 op_params, DimsToShape(input1_dims), input1_data,
1049 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1050 output_data);
1051 }
1052
// Expands to the four legacy entry points for one comparison predicate
// `name` (Equal, Less, ...):
//   1. name(...)                  -- plain elementwise comparison;
//   2. name(left_shift, ...)      -- quantized "/8bit" comparison carrying
//                                    per-input offset/multiplier/shift;
//   3. Broadcast##name(...)       -- broadcasting variant of 1;
//   4. Broadcast##name(left_shift, ...) -- broadcasting variant of 2.
// Each forwards to Comparison / BroadcastComparison above with `name##Fn`
// as the elementwise predicate.  (Explanatory comments cannot be placed
// inside the macro body because of the line continuations.)
#define TFLITE_LEGACY_COMPARISON_OP(name)                                      \
  template <typename T>                                                        \
  inline void name(const T* input1_data, const Dims<4>& input1_dims,           \
                   const T* input2_data, const Dims<4>& input2_dims,           \
                   bool* output_data, const Dims<4>& output_dims) {            \
    ruy::profiler::ScopeLabel label(#name);                                    \
    Comparison<T, name##Fn>(input1_data, input1_dims, input2_data,             \
                            input2_dims, output_data, output_dims);            \
  }                                                                            \
  template <typename T>                                                        \
  inline void name(                                                            \
      int left_shift, const T* input1_data, const Dims<4>& input1_dims,        \
      int32 input1_offset, int32 input1_multiplier, int input1_shift,          \
      const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset,   \
      int32 input2_multiplier, int input2_shift, bool* output_data,            \
      const Dims<4>& output_dims) {                                            \
    ruy::profiler::ScopeLabel label(#name "/8bit");                            \
    Comparison<T, name##Fn>(left_shift, input1_data, input1_dims,              \
                            input1_offset, input1_multiplier, input1_shift,    \
                            input2_data, input2_dims, input2_offset,           \
                            input2_multiplier, input2_shift, output_data,      \
                            output_dims);                                      \
  }                                                                            \
  template <typename T>                                                        \
  inline void Broadcast##name(                                                 \
      const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,  \
      const Dims<4>& input2_dims, bool* output_data,                           \
      const Dims<4>& output_dims) {                                            \
    ruy::profiler::ScopeLabel label("Broadcast" #name);                        \
    BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data,    \
                                     input2_dims, output_data, output_dims);   \
  }                                                                            \
  template <typename T>                                                        \
  inline void Broadcast##name(                                                 \
      int left_shift, const T* input1_data, const Dims<4>& input1_dims,        \
      int32 input1_offset, int32 input1_multiplier, int input1_shift,          \
      const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset,   \
      int32 input2_multiplier, int input2_shift, bool* output_data,            \
      const Dims<4>& output_dims) {                                            \
    ruy::profiler::ScopeLabel label("Broadcast" #name "/8bit");                \
    BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims,     \
                                     input1_offset, input1_multiplier,         \
                                     input1_shift, input2_data, input2_dims,   \
                                     input2_offset, input2_multiplier,         \
                                     input2_shift, output_data, output_dims);  \
  }
// Stamp out the six comparison predicates, then retire the macro.
TFLITE_LEGACY_COMPARISON_OP(Equal);
TFLITE_LEGACY_COMPARISON_OP(NotEqual);
TFLITE_LEGACY_COMPARISON_OP(Greater);
TFLITE_LEGACY_COMPARISON_OP(GreaterEqual);
TFLITE_LEGACY_COMPARISON_OP(Less);
TFLITE_LEGACY_COMPARISON_OP(LessEqual);
#undef TFLITE_LEGACY_COMPARISON_OP
1106
1107 template <typename D, typename T>
Select(const D * input_condition_data,const Dims<4> & input_condition_dims,const T * input_x_data,const Dims<4> & input_x_dims,const T * input_y_data,const Dims<4> & input_y_dims,T * output_data,const Dims<4> & output_dims)1108 inline void Select(const D* input_condition_data,
1109 const Dims<4>& input_condition_dims, const T* input_x_data,
1110 const Dims<4>& input_x_dims, const T* input_y_data,
1111 const Dims<4>& input_y_dims, T* output_data,
1112 const Dims<4>& output_dims) {
1113 Select(DimsToShape(input_condition_dims), input_condition_data,
1114 DimsToShape(input_x_dims), input_x_data, DimsToShape(input_y_dims),
1115 input_y_data, DimsToShape(output_dims), output_data);
1116 }
1117
1118 template <typename D, typename T>
RankOneSelect(const D * input_condition_data,const Dims<4> & input_condition_dims,const T * input_x_data,const Dims<4> & input_x_dims,const T * input_y_data,const Dims<4> & input_y_dims,T * output_data,const Dims<4> & output_dims)1119 inline void RankOneSelect(const D* input_condition_data,
1120 const Dims<4>& input_condition_dims,
1121 const T* input_x_data, const Dims<4>& input_x_dims,
1122 const T* input_y_data, const Dims<4>& input_y_dims,
1123 T* output_data, const Dims<4>& output_dims) {
1124 RankOneSelect(DimsToShape(input_condition_dims), input_condition_data,
1125 DimsToShape(input_x_dims), input_x_data,
1126 DimsToShape(input_y_dims), input_y_data,
1127 DimsToShape(output_dims), output_data);
1128 }
1129
1130 template <typename T, typename TI>
SparseToDense(const std::vector<std::vector<TI>> & indices,const T * values,T default_value,T * output_data,const Dims<4> & output_dims,bool value_is_scalar)1131 inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
1132 const T* values, T default_value, T* output_data,
1133 const Dims<4>& output_dims, bool value_is_scalar) {
1134 SparseToDense(indices, values, default_value, value_is_scalar,
1135 DimsToShape(output_dims), output_data);
1136 }
1137
1138 template <typename Scalar>
Pack(int dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)1139 void Pack(int dim, const Scalar* const* input_data,
1140 const Dims<4>* const* input_dims, int inputs_count,
1141 Scalar* output_data, const Dims<4>& output_dims) {
1142 std::vector<RuntimeShape> input_shapes(inputs_count);
1143 std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
1144 for (int i = 0; i < inputs_count; ++i) {
1145 ShapeFromDims(*input_dims[i], &input_shapes[i]);
1146 input_shapes_indirect[i] = &input_shapes[i];
1147 }
1148 tflite::PackParams op_params;
1149 op_params.axis = 3 - dim;
1150 op_params.inputs_count = inputs_count;
1151
1152 Pack(op_params, input_shapes_indirect.data(), input_data,
1153 DimsToShape(output_dims), output_data);
1154 }
1155
1156 template <typename Scalar>
Unpack(int axis,const Scalar * input_data,const Dims<4> & input_dims,int dimensions,int outputs_count,Scalar * const * output_datas,const Dims<4> & output_dims)1157 void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
1158 int dimensions, int outputs_count, Scalar* const* output_datas,
1159 const Dims<4>& output_dims) {
1160 tflite::UnpackParams op_params;
1161 op_params.axis = 3 - axis;
1162 op_params.num_split = outputs_count;
1163
1164 Unpack(op_params, DimsToShape(input_dims), input_data,
1165 DimsToShape(output_dims), output_datas);
1166 }
1167
1168 template <typename Scalar>
Pack(int dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,const int32 * input_zeropoint,const float * input_scale,int inputs_count,Scalar * output_data,const Dims<4> & output_dims,const int32 output_zeropoint,const float output_scale)1169 void Pack(int dim, const Scalar* const* input_data,
1170 const Dims<4>* const* input_dims, const int32* input_zeropoint,
1171 const float* input_scale, int inputs_count, Scalar* output_data,
1172 const Dims<4>& output_dims, const int32 output_zeropoint,
1173 const float output_scale) {
1174 std::vector<RuntimeShape> input_shapes(inputs_count);
1175 std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
1176 for (int i = 0; i < inputs_count; ++i) {
1177 ShapeFromDims(*input_dims[i], &input_shapes[i]);
1178 input_shapes_indirect[i] = &input_shapes[i];
1179 }
1180 tflite::PackParams op_params;
1181 op_params.axis = 3 - dim;
1182 op_params.input_zeropoint = input_zeropoint;
1183 op_params.input_scale = input_scale;
1184 op_params.inputs_count = inputs_count;
1185 op_params.output_zeropoint = output_zeropoint;
1186 op_params.output_scale = output_scale;
1187
1188 PackWithScaling(op_params, input_shapes_indirect.data(), input_data,
1189 DimsToShape(output_dims), output_data);
1190 }
1191
1192 template <FusedActivationFunctionType Ac>
L2Normalization(const float * input_data,const RuntimeShape & input_shape,float * output_data,const RuntimeShape & output_shape)1193 void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
1194 float* output_data, const RuntimeShape& output_shape) {
1195 static_assert(Ac == FusedActivationFunctionType::kNone, "");
1196 tflite::L2NormalizationParams op_params;
1197 // No params need to be set for float.
1198
1199 L2Normalization(op_params, input_shape, input_data, output_shape,
1200 output_data);
1201 }
1202
L2Normalization(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,uint8 * output_data,const RuntimeShape & output_shape)1203 inline void L2Normalization(const uint8* input_data,
1204 const RuntimeShape& input_shape,
1205 int32 input_zero_point, uint8* output_data,
1206 const RuntimeShape& output_shape) {
1207 tflite::L2NormalizationParams op_params;
1208 op_params.input_zero_point = input_zero_point;
1209
1210 L2Normalization(op_params, input_shape, input_data, output_shape,
1211 output_data);
1212 }
1213
1214 template <FusedActivationFunctionType Ac>
L2Normalization(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1215 void L2Normalization(const float* input_data, const Dims<4>& input_dims,
1216 float* output_data, const Dims<4>& output_dims) {
1217 L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
1218 DimsToShape(output_dims));
1219 }
1220
L2Normalization(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,uint8 * output_data,const Dims<4> & output_dims)1221 inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
1222 int32 input_zero_point, uint8* output_data,
1223 const Dims<4>& output_dims) {
1224 L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
1225 output_data, DimsToShape(output_dims));
1226 }
1227
Relu(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1228 inline void Relu(const float* input_data, const Dims<4>& input_dims,
1229 float* output_data, const Dims<4>& output_dims) {
1230 Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1231 output_data);
1232 }
1233
Relu1(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1234 inline void Relu1(const float* input_data, const Dims<4>& input_dims,
1235 float* output_data, const Dims<4>& output_dims) {
1236 Relu1(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1237 output_data);
1238 }
1239
Relu6(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1240 inline void Relu6(const float* input_data, const Dims<4>& input_dims,
1241 float* output_data, const Dims<4>& output_dims) {
1242 Relu6(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1243 output_data);
1244 }
1245
ReluX(uint8 min_value,uint8 max_value,const uint8 * input_data,const RuntimeShape & input_shape,uint8 * output_data,const RuntimeShape & output_shape)1246 inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
1247 const RuntimeShape& input_shape, uint8* output_data,
1248 const RuntimeShape& output_shape) {
1249 tflite::ActivationParams params;
1250 params.quantized_activation_max = max_value;
1251 params.quantized_activation_min = min_value;
1252 ReluX(params, input_shape, input_data, output_shape, output_data);
1253 }
1254
// Legacy fused-activation quantized Add for uint8.  The multiplier/shift
// arguments use the legacy sign convention and are negated below before
// being stored in ArithmeticParams.
template <FusedActivationFunctionType Ac>
inline void Add(int left_shift, const uint8* input1_data,
                const Dims<4>& input1_dims, int32 input1_offset,
                int32 input1_multiplier, int input1_shift,
                const uint8* input2_data, const Dims<4>& input2_dims,
                int32 input2_offset, int32 input2_multiplier, int input2_shift,
                int32 output_offset, int32 output_multiplier, int output_shift,
                int32 output_activation_min, int32 output_activation_max,
                uint8* output_data, const Dims<4>& output_dims) {
  // Local constant flips legacy shifts into the new +ve-means-left
  // convention (shadows the namespace-level constant of the same name).
  constexpr int kReverseShift = -1;
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  // kNone must leave the full uint8 range [0, 255] unclamped.
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }

  // Repackage every legacy scalar into ArithmeticParams for the new-style
  // Add, negating all shifts per the convention change above.
  tflite::ArithmeticParams op_params;
  op_params.left_shift = left_shift;
  op_params.input1_offset = input1_offset;
  op_params.input1_multiplier = input1_multiplier;
  op_params.input1_shift = kReverseShift * input1_shift;
  op_params.input2_offset = input2_offset;
  op_params.input2_multiplier = input2_multiplier;
  op_params.input2_shift = kReverseShift * input2_shift;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  op_params.output_shift = kReverseShift * output_shift;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  Add(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}
1293
1294 template <FusedActivationFunctionType Ac>
Add(const int32 * input1_data,const Dims<4> & input1_dims,const int32 * input2_data,const Dims<4> & input2_dims,int32 * output_data,const Dims<4> & output_dims)1295 void Add(const int32* input1_data, const Dims<4>& input1_dims,
1296 const int32* input2_data, const Dims<4>& input2_dims,
1297 int32* output_data, const Dims<4>& output_dims) {
1298 ruy::profiler::ScopeLabel label("Add/int32");
1299 TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
1300
1301 tflite::ArithmeticParams op_params;
1302 op_params.quantized_activation_min = std::numeric_limits<int32>::min();
1303 op_params.quantized_activation_max = std::numeric_limits<int32>::max();
1304 Add(op_params, DimsToShape(input1_dims), input1_data,
1305 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1306 output_data);
1307 }
1308
// Legacy fused-activation quantized broadcasting Add for uint8.  Same
// parameter conventions as the non-broadcast overload above; forwards to
// the 4D-slow broadcast kernel.
template <FusedActivationFunctionType Ac>
inline void BroadcastAdd(int left_shift, const uint8* input1_data,
                         const Dims<4>& input1_dims, int32 input1_offset,
                         int32 input1_multiplier, int input1_shift,
                         const uint8* input2_data, const Dims<4>& input2_dims,
                         int32 input2_offset, int32 input2_multiplier,
                         int input2_shift, int32 output_offset,
                         int32 output_multiplier, int output_shift,
                         int32 output_activation_min,
                         int32 output_activation_max, uint8* output_data,
                         const Dims<4>& output_dims) {
  // Local constant flips legacy shifts into the new +ve-means-left
  // convention (shadows the namespace-level constant of the same name).
  constexpr int kReverseShift = -1;
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  // kNone must leave the full uint8 range [0, 255] unclamped.
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }

  // Repackage the legacy scalars into ArithmeticParams, negating all
  // shifts per the convention change above.
  tflite::ArithmeticParams op_params;
  op_params.left_shift = left_shift;
  op_params.input1_offset = input1_offset;
  op_params.input1_multiplier = input1_multiplier;
  op_params.input1_shift = kReverseShift * input1_shift;
  op_params.input2_offset = input2_offset;
  op_params.input2_multiplier = input2_multiplier;
  op_params.input2_shift = kReverseShift * input2_shift;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  op_params.output_shift = kReverseShift * output_shift;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
                     DimsToShape(input2_dims), input2_data,
                     DimsToShape(output_dims), output_data);
}
1349
1350 template <FusedActivationFunctionType Ac>
Add(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1351 void Add(const float* input1_data, const Dims<4>& input1_dims,
1352 const float* input2_data, const Dims<4>& input2_dims,
1353 float* output_data, const Dims<4>& output_dims) {
1354 float output_activation_min, output_activation_max;
1355 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1356
1357 tflite::ArithmeticParams op_params;
1358 op_params.float_activation_min = output_activation_min;
1359 op_params.float_activation_max = output_activation_max;
1360 Add(op_params, DimsToShape(input1_dims), input1_data,
1361 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1362 output_data);
1363 }
1364
// Legacy broadcasting Add with explicit activation bounds of type T.
template <typename T>
void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
                  const T* input2_data, const Dims<4>& input2_dims,
                  T output_activation_min, T output_activation_max,
                  T* output_data, const Dims<4>& output_dims) {
  tflite::ArithmeticParams op_params;
  // NOTE(review): the bounds are stored in the float activation fields
  // regardless of T, so this path appears float-only; an integral T would
  // leave the quantized clamp fields unset -- confirm against callers
  // before generalizing.
  op_params.float_activation_min = output_activation_min;
  op_params.float_activation_max = output_activation_max;
  BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
                     DimsToShape(input2_dims), input2_data,
                     DimsToShape(output_dims), output_data);
}
1377
// Legacy quantized Add for the "fivefold" broadcast pattern, where the
// broadcast is described by five flattened extents y0..y4 rather than by
// the 4D shapes alone.  The extents are stored reversed into
// op_params.broadcast_shape (broadcast_shape[4] = y0 ... [0] = y4).
template <FusedActivationFunctionType Ac>
inline void BroadcastAddFivefold(
    int y0, int y1, int y2, int y3, int y4, int left_shift,
    const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
    int32 input1_multiplier, int input1_shift, const uint8* input2_data,
    const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
    int input2_shift, int32 output_offset, int32 output_multiplier,
    int output_shift, int32 output_activation_min, int32 output_activation_max,
    uint8* output_data, const Dims<4>& output_dims) {
  // Local constant flips legacy shifts into the new +ve-means-left
  // convention (shadows the namespace-level constant of the same name).
  constexpr int kReverseShift = -1;
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  // kNone must leave the full uint8 range [0, 255] unclamped.
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  tflite::ArithmeticParams op_params;
  // The legacy entry point always treats input1 as the fast-broadcast side.
  op_params.broadcast_category =
      tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
  op_params.left_shift = left_shift;
  op_params.input1_offset = input1_offset;
  op_params.input1_multiplier = input1_multiplier;
  op_params.input1_shift = kReverseShift * input1_shift;
  op_params.input2_offset = input2_offset;
  op_params.input2_multiplier = input2_multiplier;
  op_params.input2_shift = kReverseShift * input2_shift;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  op_params.output_shift = kReverseShift * output_shift;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  // Reversed fill: y0 lands in the last slot, y4 in the first.
  op_params.broadcast_shape[4] = y0;
  op_params.broadcast_shape[3] = y1;
  op_params.broadcast_shape[2] = y2;
  op_params.broadcast_shape[1] = y3;
  op_params.broadcast_shape[0] = y4;
  BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
                       DimsToShape(input2_dims), input2_data,
                       DimsToShape(output_dims), output_data);
}
1422
1423 // legacy, for compatibility with old checked-in code
1424 template <FusedActivationFunctionType Ac, typename T>
BroadcastAdd(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1425 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
1426 const T* input2_data, const Dims<4>& input2_dims,
1427 T* output_data, const Dims<4>& output_dims) {
1428 T output_activation_min, output_activation_max;
1429 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1430
1431 BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
1432 output_activation_min, output_activation_max, output_data,
1433 output_dims);
1434 }
1435
// Legacy fused-activation Add for int16 with per-operand shifts (legacy
// sign convention, flipped below).
template <FusedActivationFunctionType Ac>
inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
                int input1_shift, const int16* input2_data,
                const Dims<4>& input2_dims, int input2_shift,
                int16 output_activation_min, int16 output_activation_max,
                int16* output_data, const Dims<4>& output_dims) {
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  // kNone must leave the full int16 range [-32768, 32767] unclamped.
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, -32768);
    TFLITE_DCHECK_EQ(output_activation_max, 32767);
  }

  tflite::ArithmeticParams op_params;
  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
  // Unlike the uint8 overloads, this uses the outer kReverseShift constant
  // (no local shadow is declared here).
  op_params.input1_shift = kReverseShift * input1_shift;
  op_params.input2_shift = kReverseShift * input2_shift;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  Add(op_params, DimsToShape(input1_dims), input1_data,
      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
      output_data);
}
1462
Sub(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1463 inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
1464 const float* input2_data, const Dims<4>& input2_dims,
1465 float* output_data, const Dims<4>& output_dims) {
1466 float output_activation_min, output_activation_max;
1467 GetActivationMinMax(FusedActivationFunctionType::kNone,
1468 &output_activation_min, &output_activation_max);
1469 tflite::ArithmeticParams op_params;
1470 op_params.float_activation_min = output_activation_min;
1471 op_params.float_activation_max = output_activation_max;
1472 Sub(op_params, DimsToShape(input1_dims), input1_data,
1473 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1474 output_data);
1475 }
1476
1477 template <typename T>
Sub(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1478 void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
1479 const Dims<4>& input2_dims, T* output_data,
1480 const Dims<4>& output_dims) {
1481 tflite::ArithmeticParams op_params;
1482 op_params.quantized_activation_min = std::numeric_limits<T>::min();
1483 op_params.quantized_activation_max = std::numeric_limits<T>::max();
1484 Sub(op_params, DimsToShape(input1_dims), input1_data,
1485 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1486 output_data);
1487 }
1488
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1489 inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
1490 int stride_width, int stride_height, int pad_width,
1491 int pad_height, int kwidth, int kheight,
1492 float output_activation_min,
1493 float output_activation_max, float* output_data,
1494 const Dims<4>& output_dims) {
1495 tflite::PoolParams params;
1496 params.stride_height = stride_height;
1497 params.stride_width = stride_width;
1498 params.filter_height = kheight;
1499 params.filter_width = kwidth;
1500 params.padding_values.height = pad_height;
1501 params.padding_values.width = pad_width;
1502 params.float_activation_min = output_activation_min;
1503 params.float_activation_max = output_activation_max;
1504 AveragePool(params, DimsToShape(input_dims), input_data,
1505 DimsToShape(output_dims), output_data);
1506 }
1507
1508 // Transitional version that will be moved shortly to legacy_reference_ops, as
1509 // part of RuntimeShape revisions.
BroadcastMul4DSlow(const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1510 inline void BroadcastMul4DSlow(const uint8* input1_data,
1511 const Dims<4>& input1_dims, int32 input1_offset,
1512 const uint8* input2_data,
1513 const Dims<4>& input2_dims, int32 input2_offset,
1514 int32 output_offset, int32 output_multiplier,
1515 int output_shift, int32 output_activation_min,
1516 int32 output_activation_max, uint8* output_data,
1517 const Dims<4>& output_dims) {
1518 tflite::ArithmeticParams op_params;
1519 SetActivationParams(output_activation_min, output_activation_max, &op_params);
1520 op_params.input1_offset = input1_offset;
1521 op_params.input2_offset = input2_offset;
1522 op_params.output_offset = output_offset;
1523 op_params.output_multiplier = output_multiplier;
1524 op_params.output_shift = output_shift;
1525
1526 BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1527 DimsToShape(input2_dims), input2_data,
1528 DimsToShape(output_dims), output_data);
1529 }
1530
BroadcastMul(const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1531 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
1532 int32 input1_offset, const uint8* input2_data,
1533 const Dims<4>& input2_dims, int32 input2_offset,
1534 int32 output_offset, int32 output_multiplier,
1535 int output_shift, int32 output_activation_min,
1536 int32 output_activation_max, uint8* output_data,
1537 const Dims<4>& output_dims) {
1538 BroadcastMul4DSlow(
1539 input1_data, input1_dims, input1_offset, input2_data, input2_dims,
1540 input2_offset, output_offset, output_multiplier,
1541 //
1542 kReverseShift * output_shift,
1543 //
1544 output_activation_min, output_activation_max, output_data, output_dims);
1545 }
1546
1547 // legacy, for compatibility with old checked-in code
1548 template <FusedActivationFunctionType Ac>
BroadcastMul(const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1549 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
1550 int32 input1_offset, const uint8* input2_data,
1551 const Dims<4>& input2_dims, int32 input2_offset,
1552 int32 output_offset, int32 output_multiplier,
1553 int output_shift, int32 output_activation_min,
1554 int32 output_activation_max, uint8* output_data,
1555 const Dims<4>& output_dims) {
1556 BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
1557 input2_dims, input2_offset, output_offset, output_multiplier,
1558 output_shift, output_activation_min, output_activation_max,
1559 output_data, output_dims);
1560 }
1561
1562 // legacy, for compatibility with old checked-in code
1563 template <FusedActivationFunctionType Ac>
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float * output_data,const Dims<4> & output_dims)1564 void AveragePool(const float* input_data, const Dims<4>& input_dims,
1565 int stride_width, int stride_height, int pad_width,
1566 int pad_height, int kwidth, int kheight, float* output_data,
1567 const Dims<4>& output_dims) {
1568 float output_activation_min, output_activation_max;
1569 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1570
1571 AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
1572 pad_height, kwidth, kheight, output_activation_min,
1573 output_activation_max, output_data, output_dims);
1574 }
1575
1576 // legacy, for compatibility with old checked-in code
1577 template <FusedActivationFunctionType Ac>
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1578 void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
1579 int pad_width, int pad_height, int filter_width,
1580 int filter_height, float* output_data,
1581 const Dims<4>& output_dims) {
1582 AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1583 filter_width, filter_height, output_data, output_dims);
1584 }
1585
AveragePool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1586 inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
1587 int stride_width, int stride_height, int pad_width,
1588 int pad_height, int filter_width, int filter_height,
1589 int32 output_activation_min,
1590 int32 output_activation_max, uint8* output_data,
1591 const Dims<4>& output_dims) {
1592 tflite::PoolParams params;
1593 params.stride_height = stride_height;
1594 params.stride_width = stride_width;
1595 params.filter_height = filter_height;
1596 params.filter_width = filter_width;
1597 params.padding_values.height = pad_height;
1598 params.padding_values.width = pad_width;
1599 params.quantized_activation_min = output_activation_min;
1600 params.quantized_activation_max = output_activation_max;
1601 AveragePool(params, DimsToShape(input_dims), input_data,
1602 DimsToShape(output_dims), output_data);
1603 }
1604
1605 // legacy, for compatibility with old checked-in code
1606 template <FusedActivationFunctionType Ac>
AveragePool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1607 void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
1608 int stride_width, int stride_height, int pad_width,
1609 int pad_height, int filter_width, int filter_height,
1610 int32 output_activation_min, int32 output_activation_max,
1611 uint8* output_data, const Dims<4>& output_dims) {
1612 static_assert(Ac == FusedActivationFunctionType::kNone ||
1613 Ac == FusedActivationFunctionType::kRelu ||
1614 Ac == FusedActivationFunctionType::kRelu6 ||
1615 Ac == FusedActivationFunctionType::kRelu1,
1616 "");
1617 if (Ac == FusedActivationFunctionType::kNone) {
1618 TFLITE_DCHECK_EQ(output_activation_min, 0);
1619 TFLITE_DCHECK_EQ(output_activation_max, 255);
1620 }
1621 AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
1622 pad_height, filter_width, filter_height, output_activation_min,
1623 output_activation_max, output_data, output_dims);
1624 }
1625
1626 // legacy, for compatibility with old checked-in code
1627 template <FusedActivationFunctionType Ac>
AveragePool(const uint8 * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1628 void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
1629 int pad_width, int pad_height, int filter_width,
1630 int filter_height, int32 output_activation_min,
1631 int32 output_activation_max, uint8* output_data,
1632 const Dims<4>& output_dims) {
1633 AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1634 filter_width, filter_height, output_activation_min,
1635 output_activation_max, output_data, output_dims);
1636 }
1637
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1638 inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
1639 int stride_width, int stride_height, int pad_width,
1640 int pad_height, int kwidth, int kheight,
1641 float output_activation_min, float output_activation_max,
1642 float* output_data, const Dims<4>& output_dims) {
1643 tflite::PoolParams params;
1644 params.stride_height = stride_height;
1645 params.stride_width = stride_width;
1646 params.filter_height = kheight;
1647 params.filter_width = kwidth;
1648 params.padding_values.height = pad_height;
1649 params.padding_values.width = pad_width;
1650 params.float_activation_min = output_activation_min;
1651 params.float_activation_max = output_activation_max;
1652 MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1653 output_data);
1654 }
1655
1656 // legacy, for compatibility with old checked-in code
1657 template <FusedActivationFunctionType Ac>
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float * output_data,const Dims<4> & output_dims)1658 void MaxPool(const float* input_data, const Dims<4>& input_dims,
1659 int stride_width, int stride_height, int pad_width, int pad_height,
1660 int kwidth, int kheight, float* output_data,
1661 const Dims<4>& output_dims) {
1662 float output_activation_min, output_activation_max;
1663 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1664 MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
1665 pad_height, kwidth, kheight, output_activation_min,
1666 output_activation_max, output_data, output_dims);
1667 }
1668
1669 // legacy, for compatibility with old checked-in code
1670 template <FusedActivationFunctionType Ac>
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1671 void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
1672 int pad_width, int pad_height, int filter_width, int filter_height,
1673 float* output_data, const Dims<4>& output_dims) {
1674 MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1675 filter_width, filter_height, output_data, output_dims);
1676 }
1677
MaxPool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1678 inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
1679 int stride_width, int stride_height, int pad_width,
1680 int pad_height, int filter_width, int filter_height,
1681 int32 output_activation_min, int32 output_activation_max,
1682 uint8* output_data, const Dims<4>& output_dims) {
1683 PoolParams params;
1684 params.stride_height = stride_height;
1685 params.stride_width = stride_width;
1686 params.filter_height = filter_height;
1687 params.filter_width = filter_width;
1688 params.padding_values.height = pad_height;
1689 params.padding_values.width = pad_width;
1690 params.quantized_activation_min = output_activation_min;
1691 params.quantized_activation_max = output_activation_max;
1692 MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1693 output_data);
1694 }
1695
1696 // legacy, for compatibility with old checked-in code
1697 template <FusedActivationFunctionType Ac>
MaxPool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1698 void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
1699 int stride_width, int stride_height, int pad_width, int pad_height,
1700 int filter_width, int filter_height, int32 output_activation_min,
1701 int32 output_activation_max, uint8* output_data,
1702 const Dims<4>& output_dims) {
1703 static_assert(Ac == FusedActivationFunctionType::kNone ||
1704 Ac == FusedActivationFunctionType::kRelu ||
1705 Ac == FusedActivationFunctionType::kRelu6 ||
1706 Ac == FusedActivationFunctionType::kRelu1,
1707 "");
1708 if (Ac == FusedActivationFunctionType::kNone) {
1709 TFLITE_DCHECK_EQ(output_activation_min, 0);
1710 TFLITE_DCHECK_EQ(output_activation_max, 255);
1711 }
1712 MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
1713 pad_height, filter_width, filter_height, output_activation_min,
1714 output_activation_max, output_data, output_dims);
1715 }
1716
1717 // legacy, for compatibility with old checked-in code
1718 template <FusedActivationFunctionType Ac>
MaxPool(const uint8 * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1719 void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
1720 int pad_width, int pad_height, int filter_width, int filter_height,
1721 int32 output_activation_min, int32 output_activation_max,
1722 uint8* output_data, const Dims<4>& output_dims) {
1723 MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1724 filter_width, filter_height, output_activation_min,
1725 output_activation_max, output_data, output_dims);
1726 }
1727
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1728 inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
1729 int stride_width, int stride_height, int pad_width,
1730 int pad_height, int filter_width, int filter_height,
1731 float output_activation_min, float output_activation_max,
1732 float* output_data, const Dims<4>& output_dims) {
1733 PoolParams params;
1734 params.stride_height = stride_height;
1735 params.stride_width = stride_width;
1736 params.filter_height = filter_height;
1737 params.filter_width = filter_width;
1738 params.padding_values.height = pad_height;
1739 params.padding_values.width = pad_width;
1740 params.float_activation_min = output_activation_min;
1741 params.float_activation_max = output_activation_max;
1742 L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1743 output_data);
1744 }
1745
1746 // legacy, for compatibility with old checked-in code
1747 template <FusedActivationFunctionType Ac>
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1748 void L2Pool(const float* input_data, const Dims<4>& input_dims,
1749 int stride_width, int stride_height, int pad_width, int pad_height,
1750 int filter_width, int filter_height, float* output_data,
1751 const Dims<4>& output_dims) {
1752 float output_activation_min, output_activation_max;
1753 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1754 L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
1755 pad_height, filter_width, filter_height, output_activation_min,
1756 output_activation_max, output_data, output_dims);
1757 }
1758
1759 // legacy, for compatibility with old checked-in code
1760 template <FusedActivationFunctionType Ac>
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1761 void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
1762 int pad_width, int pad_height, int filter_width, int filter_height,
1763 float* output_data, const Dims<4>& output_dims) {
1764 L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1765 filter_width, filter_height, output_data, output_dims);
1766 }
1767
Softmax(const float * input_data,const Dims<4> & input_dims,float beta,float * output_data,const Dims<4> & output_dims)1768 inline void Softmax(const float* input_data, const Dims<4>& input_dims,
1769 float beta, float* output_data,
1770 const Dims<4>& output_dims) {
1771 Softmax(input_data, DimsToShape(input_dims), beta, output_data,
1772 DimsToShape(output_dims));
1773 }
1774
Softmax(const uint8 * input_data,const Dims<4> & input_dims,int32 input_beta_multiplier,int32 input_beta_left_shift,int diff_min,uint8 * output_data,const Dims<4> & output_dims)1775 inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
1776 int32 input_beta_multiplier, int32 input_beta_left_shift,
1777 int diff_min, uint8* output_data,
1778 const Dims<4>& output_dims) {
1779 Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier,
1780 input_beta_left_shift, diff_min, output_data,
1781 DimsToShape(output_dims));
1782 }
1783
LogSoftmax(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1784 inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
1785 float* output_data, const Dims<4>& output_dims) {
1786 LogSoftmax(input_data, DimsToShape(input_dims), output_data,
1787 DimsToShape(output_dims));
1788 }
1789
LogSoftmax(const uint8 * input_data,const Dims<4> & input_dims,int32 input_multiplier,int32 input_left_shift,int32 reverse_scaling_divisor,int32 reverse_scaling_right_shift,int diff_min,uint8 * output_data,const Dims<4> & output_dims)1790 inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
1791 int32 input_multiplier, int32 input_left_shift,
1792 int32 reverse_scaling_divisor,
1793 int32 reverse_scaling_right_shift, int diff_min,
1794 uint8* output_data, const Dims<4>& output_dims) {
1795 LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier,
1796 input_left_shift, reverse_scaling_divisor,
1797 reverse_scaling_right_shift, diff_min, output_data,
1798 DimsToShape(output_dims));
1799 }
1800
Logistic(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1801 inline void Logistic(const float* input_data, const Dims<4>& input_dims,
1802 float* output_data, const Dims<4>& output_dims) {
1803 Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1804 output_data);
1805 }
1806
Logistic(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const Dims<4> & output_dims)1807 inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
1808 int32 input_zero_point, int32 input_range_radius,
1809 int32 input_multiplier, int input_left_shift,
1810 uint8* output_data, const Dims<4>& output_dims) {
1811 Logistic(input_data, DimsToShape(input_dims), input_zero_point,
1812 input_range_radius, input_multiplier, input_left_shift, output_data,
1813 DimsToShape(output_dims));
1814 }
1815
Logistic(const int16 * input_data,const Dims<4> & input_dims,int16 * output_data,const Dims<4> & output_dims)1816 inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
1817 int16* output_data, const Dims<4>& output_dims) {
1818 Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1819 output_data);
1820 }
1821
Tanh(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1822 inline void Tanh(const float* input_data, const Dims<4>& input_dims,
1823 float* output_data, const Dims<4>& output_dims) {
1824 Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1825 output_data);
1826 }
1827
Tanh(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const Dims<4> & output_dims)1828 inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
1829 int32 input_zero_point, int32 input_range_radius,
1830 int32 input_multiplier, int input_left_shift,
1831 uint8* output_data, const Dims<4>& output_dims) {
1832 Tanh(input_data, DimsToShape(input_dims), input_zero_point,
1833 input_range_radius, input_multiplier, input_left_shift, output_data,
1834 DimsToShape(output_dims));
1835 }
1836
Tanh(const int16 * input_data,const Dims<4> & input_dims,int input_left_shift,int16 * output_data,const Dims<4> & output_dims)1837 inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
1838 int input_left_shift, int16* output_data,
1839 const Dims<4>& output_dims) {
1840 Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data,
1841 DimsToShape(output_dims));
1842 }
1843
1844 template <typename T>
DepthToSpace(const T * input_data,const Dims<4> & input_dims,int block_size,T * output_data,const Dims<4> & output_dims)1845 inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
1846 int block_size, T* output_data,
1847 const Dims<4>& output_dims) {
1848 tflite::DepthToSpaceParams op_params;
1849 op_params.block_size = block_size;
1850
1851 DepthToSpace(op_params, DimsToShape(input_dims), input_data,
1852 DimsToShape(output_dims), output_data);
1853 }
1854
1855 template <typename T>
SpaceToDepth(const T * input_data,const Dims<4> & input_dims,int block_size,T * output_data,const Dims<4> & output_dims)1856 inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
1857 int block_size, T* output_data,
1858 const Dims<4>& output_dims) {
1859 tflite::SpaceToDepthParams op_params;
1860 op_params.block_size = block_size;
1861
1862 SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
1863 DimsToShape(output_dims), output_data);
1864 }
1865
1866 template <typename T>
Mul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1867 inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
1868 const T* input2_data, const Dims<4>& input2_dims,
1869 T output_activation_min, T output_activation_max,
1870 T* output_data, const Dims<4>& output_dims) {
1871 tflite::ArithmeticParams op_params;
1872 SetActivationParams(output_activation_min, output_activation_max, &op_params);
1873
1874 Mul(op_params, DimsToShape(input1_dims), input1_data,
1875 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1876 output_data);
1877 }
1878
1879 // legacy, for compatibility with old checked-in code
1880 template <FusedActivationFunctionType Ac>
Mul(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1881 void Mul(const float* input1_data, const Dims<4>& input1_dims,
1882 const float* input2_data, const Dims<4>& input2_dims,
1883 float* output_data, const Dims<4>& output_dims) {
1884 float output_activation_min, output_activation_max;
1885 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1886
1887 tflite::ArithmeticParams op_params;
1888 SetActivationParams(output_activation_min, output_activation_max, &op_params);
1889
1890 Mul(op_params, DimsToShape(input1_dims), input1_data,
1891 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1892 output_data);
1893 }
1894
1895 template <typename T>
BroadcastMul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1896 void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
1897 const T* input2_data, const Dims<4>& input2_dims,
1898 T output_activation_min, T output_activation_max,
1899 T* output_data, const Dims<4>& output_dims) {
1900 tflite::ArithmeticParams op_params;
1901 SetActivationParams(output_activation_min, output_activation_max, &op_params);
1902
1903 BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1904 DimsToShape(input2_dims), input2_data,
1905 DimsToShape(output_dims), output_data);
1906 }
1907
1908 // legacy, for compatibility with old checked-in code
1909 template <FusedActivationFunctionType Ac, typename T>
BroadcastMul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1910 void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
1911 const T* input2_data, const Dims<4>& input2_dims,
1912 T* output_data, const Dims<4>& output_dims) {
1913 T output_activation_min, output_activation_max;
1914 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1915
1916 tflite::ArithmeticParams op_params;
1917 SetActivationParams(output_activation_min, output_activation_max, &op_params);
1918
1919 BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1920 DimsToShape(input2_dims), input2_data,
1921 DimsToShape(output_dims), output_data);
1922 }
1923
Mul(const int16 * input1_data,const Dims<4> & input1_dims,const int16 * input2_data,const Dims<4> & input2_dims,int16 * output_data,const Dims<4> & output_dims)1924 inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
1925 const int16* input2_data, const Dims<4>& input2_dims,
1926 int16* output_data, const Dims<4>& output_dims) {
1927 tflite::ArithmeticParams op_params;
1928 // No params in this version.
1929
1930 Mul(op_params, DimsToShape(input1_dims), input1_data,
1931 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1932 output_data);
1933 }
1934
Mul(const int16 * input1_data,const Dims<4> & input1_dims,const int16 * input2_data,const Dims<4> & input2_dims,int32 output_offset,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1935 inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
1936 const int16* input2_data, const Dims<4>& input2_dims,
1937 int32 output_offset, int32 output_activation_min,
1938 int32 output_activation_max, uint8* output_data,
1939 const Dims<4>& output_dims) {
1940 tflite::ArithmeticParams op_params;
1941 op_params.quantized_activation_min = output_activation_min;
1942 op_params.quantized_activation_max = output_activation_max;
1943 op_params.output_offset = output_offset;
1944
1945 Mul(op_params, DimsToShape(input1_dims), input1_data,
1946 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1947 output_data);
1948 }
1949
LocalResponseNormalization(const float * input_data,const Dims<4> & input_dims,int range,float bias,float alpha,float beta,float * output_data,const Dims<4> & output_dims)1950 inline void LocalResponseNormalization(const float* input_data,
1951 const Dims<4>& input_dims, int range,
1952 float bias, float alpha, float beta,
1953 float* output_data,
1954 const Dims<4>& output_dims) {
1955 tflite::LocalResponseNormalizationParams op_params;
1956 op_params.range = range;
1957 op_params.bias = bias;
1958 op_params.alpha = alpha;
1959 op_params.beta = beta;
1960
1961 LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
1962 DimsToShape(output_dims), output_data);
1963 }
1964
1965 template <typename SrcT, typename DstT>
Cast(const SrcT * input_data,const Dims<4> & input_dims,DstT * output_data,const Dims<4> & output_dims)1966 void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
1967 const Dims<4>& output_dims) {
1968 Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1969 output_data);
1970 }
1971
Floor(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1972 inline void Floor(const float* input_data, const Dims<4>& input_dims,
1973 float* output_data, const Dims<4>& output_dims) {
1974 Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1975 output_data);
1976 }
1977
1978 template <typename T>
ResizeBilinear(const T * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,T * output_data,const Dims<4> & output_dims,bool align_corners)1979 inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
1980 const int32* output_size_data,
1981 const Dims<4>& output_size_dims, T* output_data,
1982 const Dims<4>& output_dims, bool align_corners) {
1983 tflite::ResizeBilinearParams op_params;
1984 op_params.align_corners = align_corners;
1985 op_params.half_pixel_centers = false;
1986 ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
1987 DimsToShape(output_size_dims), output_size_data,
1988 DimsToShape(output_dims), output_data);
1989 }
1990
1991 // legacy, for compatibility with old checked-in code
ResizeBilinear(const float * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,float * output_data,const Dims<4> & output_dims)1992 inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
1993 const int32* output_size_data,
1994 const Dims<4>& output_size_dims, float* output_data,
1995 const Dims<4>& output_dims) {
1996 ResizeBilinear<float>(input_data, input_dims, output_size_data,
1997 output_size_dims, output_data, output_dims,
1998 /*align_corners=*/false);
1999 }
2000
ResizeBilinear(const uint8 * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,uint8 * output_data,const Dims<4> & output_dims)2001 inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
2002 const int32* output_size_data,
2003 const Dims<4>& output_size_dims, uint8* output_data,
2004 const Dims<4>& output_dims) {
2005 ResizeBilinear<uint8>(input_data, input_dims, output_size_data,
2006 output_size_dims, output_data, output_dims,
2007 /*align_corners=*/false);
2008 }
2009
2010 template <typename T>
SpaceToBatchND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * paddings_data,const Dims<4> & paddings_dims,T * output_data,const Dims<4> & output_dims,const int32_t pad_value)2011 inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
2012 const int32* block_shape_data,
2013 const Dims<4>& block_shape_dims,
2014 const int32* paddings_data,
2015 const Dims<4>& paddings_dims, T* output_data,
2016 const Dims<4>& output_dims,
2017 const int32_t pad_value) {
2018 tflite::SpaceToBatchParams op_params;
2019 op_params.output_offset = pad_value;
2020
2021 SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
2022 DimsToShape(block_shape_dims), block_shape_data,
2023 DimsToShape(paddings_dims), paddings_data,
2024 DimsToShape(output_dims), output_data);
2025 }
2026
2027 template <typename T>
SpaceToBatchND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * paddings_data,const Dims<4> & paddings_dims,T * output_data,const Dims<4> & output_dims)2028 inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
2029 const int32* block_shape_data,
2030 const Dims<4>& block_shape_dims,
2031 const int32* paddings_data,
2032 const Dims<4>& paddings_dims, T* output_data,
2033 const Dims<4>& output_dims) {
2034 tflite::SpaceToBatchParams op_params;
2035 op_params.output_offset = 0;
2036
2037 SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
2038 DimsToShape(block_shape_dims), block_shape_data,
2039 DimsToShape(paddings_dims), paddings_data,
2040 DimsToShape(output_dims), output_data);
2041 }
2042
2043 template <typename T>
BatchToSpaceND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * crops_data,const Dims<4> & crops_dims,T * output_data,const Dims<4> & output_dims)2044 inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
2045 const int32* block_shape_data,
2046 const Dims<4>& block_shape_dims,
2047 const int32* crops_data, const Dims<4>& crops_dims,
2048 T* output_data, const Dims<4>& output_dims) {
2049 BatchToSpaceND(DimsToShape(input_dims), input_data,
2050 DimsToShape(block_shape_dims), block_shape_data,
2051 DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
2052 output_data);
2053 }
2054
2055 // Legacy signature, function covered both Pad and PadV2.
2056 template <typename T>
PadV2(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims,const T pad_value)2057 inline void PadV2(const T* input_data, const Dims<4>& input_dims,
2058 const std::vector<int>& left_paddings,
2059 const std::vector<int>& right_paddings, T* output_data,
2060 const Dims<4>& output_dims, const T pad_value) {
2061 TFLITE_DCHECK_EQ(left_paddings.size(), 4);
2062 TFLITE_DCHECK_EQ(right_paddings.size(), 4);
2063 tflite::PadParams op_params;
2064 op_params.left_padding_count = 4;
2065 op_params.right_padding_count = 4;
2066 for (int i = 0; i < 4; ++i) {
2067 op_params.left_padding[i] = left_paddings[3 - i];
2068 op_params.right_padding[i] = right_paddings[3 - i];
2069 }
2070 // SetFloatOrInt(pad_value, &op_params.pad_value);
2071 const T pad_value_copy = pad_value;
2072
2073 Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
2074 DimsToShape(output_dims), output_data);
2075 }
2076
2077 // Old Pad that calls legacy PadV2.
2078 template <typename T>
Pad(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims,const int32_t pad_value)2079 inline void Pad(const T* input_data, const Dims<4>& input_dims,
2080 const std::vector<int>& left_paddings,
2081 const std::vector<int>& right_paddings, T* output_data,
2082 const Dims<4>& output_dims, const int32_t pad_value) {
2083 const T converted_pad_value = static_cast<T>(pad_value);
2084 PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
2085 output_dims, converted_pad_value);
2086 }
2087
2088 // Old Pad that only padded with 0.
2089 template <typename T>
Pad(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims)2090 inline void Pad(const T* input_data, const Dims<4>& input_dims,
2091 const std::vector<int>& left_paddings,
2092 const std::vector<int>& right_paddings, T* output_data,
2093 const Dims<4>& output_dims) {
2094 const T pad_value = static_cast<T>(0);
2095 PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
2096 output_dims, pad_value);
2097 }
2098
2099 template <typename T>
TensorFlowMinimum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,T * output_data,const Dims<4> & output_dims)2100 void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
2101 const T* input2_data, T* output_data,
2102 const Dims<4>& output_dims) {
2103 Minimum(DimsToShape(input1_dims), input1_data, input2_data,
2104 DimsToShape(output_dims), output_data);
2105 }
2106
2107 template <typename T>
TensorFlowMaximum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,T * output_data,const Dims<4> & output_dims)2108 void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
2109 const T* input2_data, T* output_data,
2110 const Dims<4>& output_dims) {
2111 Maximum(DimsToShape(input1_dims), input1_data, input2_data,
2112 DimsToShape(output_dims), output_data);
2113 }
2114
2115 template <typename T, typename Op>
TensorFlowMaximumMinimum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims,Op op)2116 void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
2117 const T* input2_data, const Dims<4>& input2_dims,
2118 T* output_data, const Dims<4>& output_dims,
2119 Op op) {
2120 MaximumMinimumBroadcastSlow(DimsToShape(input1_dims), input1_data,
2121 DimsToShape(input2_dims), input2_data,
2122 DimsToShape(output_dims), output_data, op);
2123 }
2124
// Legacy ArgMax entry point. Reduces over the last dimension of the 4-D
// input and emits a 3-D result via the comparator std::greater.
template <typename T1, typename T2, typename T3>
void ArgMax(const T3* axis, const T1* input_data,
            const tflite::Dims<4>& input_dims, T2* output_data,
            const tflite::Dims<4>& output_dims) {
  // Assumes the input always has 4 dimensions, and therefore,
  // output always has three dimensions.
  auto output_shape = RuntimeShape(
      {output_dims.sizes[2], output_dims.sizes[1], output_dims.sizes[0]});
  // Another way to interpret this is that output_dims.sizes[4] is always 1.
  // NOTE(review): "sizes[4]" is out of range for Dims<4>; for the flat-size
  // check below to hold, it is output_dims.sizes[3] that must be 1 — the
  // 3-D shape built above drops exactly that dimension. TODO: confirm and
  // fix the index in the comment.
  TFLITE_DCHECK_EQ(output_shape.FlatSize(),
                   DimsToShape(output_dims).FlatSize());
  // Legacy path only supported this.
  TFLITE_DCHECK_EQ(axis[0], 3);
  ArgMinMax(DimsToShape(input_dims), input_data, axis, output_shape,
            output_data, std::greater<T1>());
}
2141
2142 template <typename T1, typename T2, typename T3, typename Cmp>
ArgMinMax(const T3 * axis,const T1 * input_data,const Dims<4> & input_dims,T2 * output_data,const Dims<4> & output_dims,const Cmp & cmp)2143 void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
2144 T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
2145 ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
2146 output_data, cmp);
2147 }
2148
2149 template <typename T>
Pow(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)2150 inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
2151 const T* input2_data, const Dims<4>& input2_dims,
2152 T* output_data, const Dims<4>& output_dims) {
2153 Pow(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
2154 input2_data, DimsToShape(output_dims), output_data);
2155 }
2156
2157 template <typename T>
BroadcastPow(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)2158 inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
2159 const T* input2_data, const Dims<4>& input2_dims,
2160 T* output_data, const Dims<4>& output_dims) {
2161 BroadcastPow4DSlow(DimsToShape(input1_dims), input1_data,
2162 DimsToShape(input2_dims), input2_data,
2163 DimsToShape(output_dims), output_data);
2164 }
2165
2166 // R: Result type. T1: Input 1 type. T2: Input 2 type.
2167 template <typename R, typename T1, typename T2>
BroadcastBinaryFunction(const T1 * input1_data,const Dims<4> & input1_dims,const T2 * input2_data,const Dims<4> & input2_dims,R * output_data,const Dims<4> & output_dims,R (* func)(T1,T2))2168 inline void BroadcastBinaryFunction(const T1* input1_data,
2169 const Dims<4>& input1_dims,
2170 const T2* input2_data,
2171 const Dims<4>& input2_dims, R* output_data,
2172 const Dims<4>& output_dims,
2173 R (*func)(T1, T2)) {
2174 BroadcastBinaryFunction(DimsToShape(input1_dims), input1_data,
2175 DimsToShape(input2_dims), input2_data,
2176 DimsToShape(output_dims), output_data, func);
2177 }
2178
2179 // R: Result type. T1: Input 1 type. T2: Input 2 type.
2180 template <typename R, typename T1, typename T2>
BinaryFunction(const T1 * input1_data,const Dims<4> & input1_dims,const T2 * input2_data,const Dims<4> & input2_dims,R * output_data,const Dims<4> & output_dims,R (* func)(T1,T2))2181 inline void BinaryFunction(const T1* input1_data, const Dims<4>& input1_dims,
2182 const T2* input2_data, const Dims<4>& input2_dims,
2183 R* output_data, const Dims<4>& output_dims,
2184 R (*func)(T1, T2)) {
2185 BinaryFunction(DimsToShape(input1_dims), input1_data,
2186 DimsToShape(input2_dims), input2_data,
2187 DimsToShape(output_dims), output_data, func);
2188 }
2189
2190 template <typename T>
Slice(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & begin,const std::vector<int> & size,T * output_data,const Dims<4> & output_dims)2191 inline void Slice(const T* input_data, const Dims<4>& input_dims,
2192 const std::vector<int>& begin, const std::vector<int>& size,
2193 T* output_data, const Dims<4>& output_dims) {
2194 tflite::SliceParams op_params;
2195 op_params.begin_count = 4;
2196 op_params.size_count = 4;
2197 for (int i = 0; i < 4; ++i) {
2198 op_params.begin[i] = begin[3 - i];
2199 op_params.size[i] = size[3 - i];
2200 }
2201
2202 Slice(op_params, DimsToShape(input_dims), input_data,
2203 DimsToShape(output_dims), output_data);
2204 }
2205
2206 } // namespace reference_ops
2207 } // namespace tflite
2208 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
2209