/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_

#include <stdint.h>
#include <sys/types.h>

#include <algorithm>
#include <array>
#include <cmath>
#include <cstring>
#include <functional>
#include <limits>
#include <memory>
#include <type_traits>

#include "Eigen/Core"
#include "fixedpoint/fixedpoint.h"
#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/add.h"
#include "tensorflow/lite/kernels/internal/reference/add_n.h"
#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
#include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"
#include "tensorflow/lite/kernels/internal/reference/batch_to_space_nd.h"
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
#include "tensorflow/lite/kernels/internal/reference/cast.h"
#include "tensorflow/lite/kernels/internal/reference/ceil.h"
#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
#include "tensorflow/lite/kernels/internal/reference/concatenation.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/depth_to_space.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/reference/div.h"
#include "tensorflow/lite/kernels/internal/reference/elu.h"
#include "tensorflow/lite/kernels/internal/reference/exp.h"
#include "tensorflow/lite/kernels/internal/reference/fill.h"
#include "tensorflow/lite/kernels/internal/reference/floor.h"
#include "tensorflow/lite/kernels/internal/reference/floor_div.h"
#include "tensorflow/lite/kernels/internal/reference/floor_mod.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/gather.h"
#include "tensorflow/lite/kernels/internal/reference/hard_swish.h"
#include "tensorflow/lite/kernels/internal/reference/l2normalization.h"
#include "tensorflow/lite/kernels/internal/reference/leaky_relu.h"
#include "tensorflow/lite/kernels/internal/reference/log_softmax.h"
#include "tensorflow/lite/kernels/internal/reference/logistic.h"
#include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
#include "tensorflow/lite/kernels/internal/reference/mul.h"
#include "tensorflow/lite/kernels/internal/reference/neg.h"
#include "tensorflow/lite/kernels/internal/reference/pad.h"
#include "tensorflow/lite/kernels/internal/reference/pooling.h"
#include "tensorflow/lite/kernels/internal/reference/prelu.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/reference/quantize.h"
#include "tensorflow/lite/kernels/internal/reference/reduce.h"
#include "tensorflow/lite/kernels/internal/reference/requantize.h"
#include "tensorflow/lite/kernels/internal/reference/resize_bilinear.h"
#include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
#include "tensorflow/lite/kernels/internal/reference/round.h"
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h"
#include "tensorflow/lite/kernels/internal/reference/space_to_depth.h"
#include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
#include "tensorflow/lite/kernels/internal/reference/string_comparisons.h"
#include "tensorflow/lite/kernels/internal/reference/sub.h"
#include "tensorflow/lite/kernels/internal/reference/tanh.h"
#include "tensorflow/lite/kernels/internal/reference/transpose.h"
#include "tensorflow/lite/kernels/internal/reference/transpose_conv.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {

namespace reference_ops {

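// Relu: elementwise max(x, 0). Input and output must contain the same number
// of elements.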
template <typename T>
inline void Relu(const RuntimeShape& input_shape, const T* input_data,
                 const RuntimeShape& output_shape, T* output_data) {
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const T val = input_data[i];
    const T lower = 0;
    const T clamped = val < lower ? lower : val;
    output_data[i] = clamped;
  }
}

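// Relu1: elementwise clamp of the input to [-1, 1].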
template <typename T>
inline void Relu1(const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("Relu1 (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const T val = input_data[i];
    const T upper = 1;
    const T lower = -1;
    const T clamped = val > upper ? upper : val < lower ? lower : val;
    output_data[i] = clamped;
  }
}

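// Relu6: elementwise clamp of the input to [0, 6].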
inline void Relu6(const RuntimeShape& input_shape, const float* input_data,
                  const RuntimeShape& output_shape, float* output_data) {
  ruy::profiler::ScopeLabel label("Relu6 (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const float val = input_data[i];
    const float upper = 6;
    const float lower = 0;
    const float clamped = val > upper ? upper : val < lower ? lower : val;
    output_data[i] = clamped;
  }
}

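// Quantized ReluX: requantizes each value with the offset/multiplier/shift in
// `params`, then clamps it to the quantized activation range.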
template <typename T>
inline void ReluX(const tflite::ReluParams& params,
                  const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const int32 val = static_cast<int32_t>(input_data[i]);
    int32 clamped = params.output_offset +
                    MultiplyByQuantizedMultiplier(val - params.input_offset,
                                                  params.output_multiplier,
                                                  params.output_shift);
    clamped = std::max(params.quantized_activation_min, clamped);
    clamped = std::min(params.quantized_activation_max, clamped);
    output_data[i] = static_cast<T>(clamped);
  }
}

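// ReluX variant that only clamps to the activation range; it assumes the input
// and output share the same quantization parameters, so no requantization is
// performed.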
template <typename T>
inline void ReluX(const tflite::ActivationParams& params,
                  const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  const T max_value = params.quantized_activation_max;
  const T min_value = params.quantized_activation_min;
  for (int i = 0; i < flat_size; ++i) {
    const T val = input_data[i];
    const T clamped = val > max_value   ? max_value
                      : val < min_value ? min_value
                                        : val;
    output_data[i] = clamped;
  }
}

// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
                                 const RuntimeShape& unswitched_input1_shape,
                                 const uint8* unswitched_input1_data,
                                 const RuntimeShape& unswitched_input2_shape,
                                 const uint8* unswitched_input2_data,
                                 const RuntimeShape& output_shape,
                                 uint8* output_data) {
  ArithmeticParams switched_params = unswitched_params;
  switched_params.input1_offset = unswitched_params.input2_offset;
  switched_params.input2_offset = unswitched_params.input1_offset;

  const bool use_unswitched =
      unswitched_params.broadcast_category ==
      tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;

  const ArithmeticParams& params =
      use_unswitched ? unswitched_params : switched_params;
  const uint8* input1_data =
      use_unswitched ? unswitched_input1_data : unswitched_input2_data;
  const uint8* input2_data =
      use_unswitched ? unswitched_input2_data : unswitched_input1_data;

  // Fivefold nested loops. The second input resets its position for each
  // iteration of the second loop. The first input resets its position at the
  // beginning of the fourth loop. The innermost loop is an elementwise Mul of
  // sections of the arrays.
  uint8* output_data_ptr = output_data;
  const uint8* input1_data_ptr = input1_data;
  const uint8* input2_data_reset = input2_data;
  int y0 = params.broadcast_shape[0];
  int y1 = params.broadcast_shape[1];
  int y2 = params.broadcast_shape[2];
  int y3 = params.broadcast_shape[3];
  int y4 = params.broadcast_shape[4];
  for (int i0 = 0; i0 < y0; ++i0) {
    const uint8* input2_data_ptr;
    for (int i1 = 0; i1 < y1; ++i1) {
      input2_data_ptr = input2_data_reset;
      for (int i2 = 0; i2 < y2; ++i2) {
        for (int i3 = 0; i3 < y3; ++i3) {
          MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
                         output_data_ptr);
          input2_data_ptr += y4;
          output_data_ptr += y4;
        }
        input1_data_ptr += y4;
      }
    }
    input2_data_reset = input2_data_ptr;
  }
}

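// Elementwise Mul of two int16 tensors interpreted as fixed-point values with
// 0 integer bits (range [-1, 1]); the raw product is written out unclamped.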
inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const int16* input1_data,
                const RuntimeShape& input2_shape, const int16* input2_data,
                const RuntimeShape& output_shape, int16* output_data) {
  ruy::profiler::ScopeLabel label("Mul/Int16");

  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;

    F0 unclamped_result =
        F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
    output_data[i] = unclamped_result.raw();
  }
}

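// Same as the int16 Mul above, but the fixed-point product is rescaled,
// clamped to the quantized activation range and stored as uint8 with the
// given output offset.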
inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const int16* input1_data,
                const RuntimeShape& input2_shape, const int16* input2_data,
                const RuntimeShape& output_shape, uint8* output_data) {
  ruy::profiler::ScopeLabel label("Mul/Int16Uint8");
  int32 output_offset = params.output_offset;
  int32 output_activation_min = params.quantized_activation_min;
  int32 output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);

  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;

    F0 unclamped_result =
        F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
    int16 rescaled_result =
        gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
    int16 clamped_result =
        std::min<int16>(output_activation_max - output_offset, rescaled_result);
    clamped_result =
        std::max<int16>(output_activation_min - output_offset, clamped_result);
    output_data[i] = output_offset + clamped_result;
  }
}

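// Sub16: elementwise int16 fixed-point subtraction. At most one input carries
// a non-zero (negative) shift; that input is scaled down by a rounding right
// shift before the saturating subtraction, and the result is clamped to the
// activation range.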
inline void Sub16(const ArithmeticParams& params,
                  const RuntimeShape& input1_shape, const int16_t* input1_data,
                  const RuntimeShape& input2_shape, const int16_t* input2_data,
                  const RuntimeShape& output_shape, int16_t* output_data) {
  ruy::profiler::ScopeLabel label("Sub/Int16");
  const int input1_shift = params.input1_shift;
  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);
  const int16 output_activation_min = params.quantized_activation_min;
  const int16 output_activation_max = params.quantized_activation_max;

  TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
  TFLITE_DCHECK_LE(input1_shift, 0);
  TFLITE_DCHECK_LE(params.input2_shift, 0);
  const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
  const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
  const int input_right_shift =
      input1_shift == 0 ? -params.input2_shift : -input1_shift;

  if (input1_shift == 0) {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
    for (int i = 0; i < flat_size; ++i) {
      F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
      F0 scaled_input = F0::FromRaw(
          gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
      F0 result = SaturatingSub(input_ready_scaled, scaled_input);
      const int16 raw_output = result.raw();
      const int16 clamped_output = std::min(
          output_activation_max, std::max(output_activation_min, raw_output));
      output_data[i] = clamped_output;
    }
  } else {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
    for (int i = 0; i < flat_size; ++i) {
      F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
      F0 scaled_input = F0::FromRaw(
          gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
      F0 result = SaturatingSub(scaled_input, input_ready_scaled);
      const int16 raw_output = result.raw();
      const int16 clamped_output = std::min(
          output_activation_max, std::max(output_activation_min, raw_output));
      output_data[i] = clamped_output;
    }
  }
}

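// Pack: stacks `inputs_count` tensors of identical shape along `params.axis`,
// interleaving one `copy_size`-element block from each input per outer index.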
template <typename Scalar>
void Pack(const PackParams& params, const RuntimeShape* const* input_shapes,
          const Scalar* const* input_data, const RuntimeShape& output_shape,
          Scalar* output_data) {
  ruy::profiler::ScopeLabel label("Pack");
  const int dimensions = output_shape.DimensionsCount();
  int axis = params.axis;
  int inputs_count = params.inputs_count;

  int outer_size = 1;
  for (int i = 0; i < axis; i++) {
    outer_size *= output_shape.Dims(i);
  }
  int copy_size = 1;
  for (int i = params.axis + 1; i < dimensions; i++) {
    copy_size *= output_shape.Dims(i);
  }
  TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);

  for (int i = 0; i < inputs_count; ++i) {
    for (int k = 0; k < outer_size; k++) {
      const Scalar* input_ptr = input_data[i] + copy_size * k;
      int loc = k * inputs_count * copy_size + i * copy_size;
      memcpy(output_data + loc, input_ptr, copy_size * sizeof(Scalar));
    }
  }
}

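// Unpack: the inverse of Pack. Splits the input along `params.axis` (which may
// be negative) into `params.num_split` outputs of identical shape.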
template <typename Scalar>
void Unpack(const UnpackParams& params, const RuntimeShape& input_shape,
            const Scalar* input_data, const RuntimeShape& output_shape,
            Scalar* const* output_datas) {
  ruy::profiler::ScopeLabel label("Unpack");
  const int dimensions = input_shape.DimensionsCount();
  const int outputs_count = params.num_split;

  int outer_size = 1;
  int axis = params.axis;
  if (axis < 0) {
    axis += dimensions;
  }
  TFLITE_DCHECK_GE(axis, 0);
  TFLITE_DCHECK_LT(axis, dimensions);
  for (int i = 0; i < axis; ++i) {
    outer_size *= input_shape.Dims(i);
  }
  int copy_size = 1;
  for (int i = axis + 1; i < dimensions; ++i) {
    copy_size *= input_shape.Dims(i);
  }
  TFLITE_DCHECK_EQ(output_shape.FlatSize(), copy_size * outer_size);

  for (int i = 0; i < outputs_count; ++i) {
    for (int k = 0; k < outer_size; k++) {
      Scalar* output_ptr = output_datas[i] + copy_size * k;
      int loc = k * outputs_count * copy_size + i * copy_size;
      memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
    }
  }
}

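// PackWithScaling: Pack for uint8 data with per-input quantization parameters.
// Inputs that already match the output zero point and scale are memcpy'd;
// otherwise the values are requantized element by element.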
template <typename Scalar>
void PackWithScaling(const PackParams& params,
                     const RuntimeShape* const* input_shapes,
                     const uint8* const* input_data,
                     const RuntimeShape& output_shape, uint8* output_data) {
  ruy::profiler::ScopeLabel label("PackWithScaling");
  const int dimensions = output_shape.DimensionsCount();
  int axis = params.axis;
  const int32* input_zeropoint = params.input_zeropoint;
  const float* input_scale = params.input_scale;
  int inputs_count = params.inputs_count;
  const int32 output_zeropoint = params.output_zeropoint;
  const float output_scale = params.output_scale;

  int outer_size = 1;
  for (int i = 0; i < axis; i++) {
    outer_size *= output_shape.Dims(i);
  }
  int copy_size = 1;
  for (int i = axis + 1; i < dimensions; i++) {
    copy_size *= output_shape.Dims(i);
  }
  TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);

  Scalar* output_ptr = output_data;
  const float inverse_output_scale = 1.f / output_scale;
  for (int k = 0; k < outer_size; k++) {
    for (int i = 0; i < inputs_count; ++i) {
      if (input_zeropoint[i] == output_zeropoint &&
          input_scale[i] == output_scale) {
        memcpy(output_ptr, input_data[i] + k * copy_size,
               copy_size * sizeof(Scalar));
      } else {
        assert(false);
        const float scale = input_scale[i] * inverse_output_scale;
        const float bias = -input_zeropoint[i] * scale;
        auto input_ptr = input_data[i];
        for (int j = 0; j < copy_size; ++j) {
          const int32_t value =
              static_cast<int32_t>(std::round(input_ptr[j] * scale + bias)) +
              output_zeropoint;
          output_ptr[j] =
              static_cast<uint8_t>(std::max(std::min(255, value), 0));
        }
      }
      output_ptr += copy_size;
    }
  }
}

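// DepthConcatenation: concatenation along the depth (last) axis of 4-D
// tensors, implemented by forwarding to Concatenation with axis = 3.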
template <typename Scalar>
void DepthConcatenation(const ConcatenationParams& params,
                        const RuntimeShape* const* input_shapes,
                        const Scalar* const* input_data,
                        const RuntimeShape& output_shape, Scalar* output_data) {
  ruy::profiler::ScopeLabel label("DepthConcatenation");
  auto params_copy = params;
  params_copy.axis = 3;
  Concatenation(params_copy, input_shapes, input_data, output_shape,
                output_data);
}

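// Float LSTM cell: concatenates the input with the previous activations, runs
// a single fully connected layer producing the four gate pre-activations, then
// updates the cell state and the output activations.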
inline void LstmCell(
    const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
    const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
    const float* prev_activ_data, const RuntimeShape& weights_shape,
    const float* weights_data, const RuntimeShape& unextended_bias_shape,
    const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
    const float* prev_state_data,
    const RuntimeShape& unextended_output_state_shape, float* output_state_data,
    const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
    const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
    const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
  const RuntimeShape input_shape =
      RuntimeShape::ExtendedShape(4, unextended_input_shape);
  const RuntimeShape prev_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
  const RuntimeShape bias_shape =
      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
  const RuntimeShape prev_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
  const RuntimeShape output_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
  const RuntimeShape output_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
  const RuntimeShape concat_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
  const RuntimeShape activ_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);

  const int weights_dim_count = weights_shape.DimensionsCount();
  const int batches =
      MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
                  output_state_shape, 0, output_activ_shape, 0);
  const int height =
      MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
                  output_state_shape, 1, output_activ_shape, 1);
  const int width =
      MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
                  output_state_shape, 2, output_activ_shape, 2);
  const int input_depth = input_shape.Dims(3);
  const int prev_activ_depth = prev_activ_shape.Dims(3);
  const int total_input_depth = prev_activ_depth + input_depth;
  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
                   total_input_depth);
  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
  const int intern_activ_depth =
      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
                   intern_activ_depth * total_input_depth);
  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
  const int output_depth =
      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
                  3, output_activ_shape, 3);
  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);

  // Concatenate prev_activ and input data together
  std::vector<float const*> concat_input_arrays_data;
  std::vector<RuntimeShape const*> concat_input_arrays_shapes;
  concat_input_arrays_data.push_back(input_data);
  concat_input_arrays_data.push_back(prev_activ_data);
  concat_input_arrays_shapes.push_back(&input_shape);
  concat_input_arrays_shapes.push_back(&prev_activ_shape);
  tflite::ConcatenationParams concat_params;
  concat_params.axis = 3;
  concat_params.inputs_count = concat_input_arrays_data.size();
  Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
                &(concat_input_arrays_data[0]), concat_temp_shape,
                concat_temp_data);

  // Fully connected
  tflite::FullyConnectedParams fc_params;
  fc_params.float_activation_min = std::numeric_limits<float>::lowest();
  fc_params.float_activation_max = std::numeric_limits<float>::max();
  FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
                 weights_data, bias_shape, bias_data, activ_temp_shape,
                 activ_temp_data);

  // Memory state update (the LSTM "guts")
  for (int b = 0; b < batches; ++b) {
    for (int w = 0; w < width; ++w) {
      for (int h = 0; h < height; ++h) {
        for (int c = 0; c < output_depth; ++c) {
          const float input_gate =
              1.f /
              (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
                                                      0 * output_depth + c)]));
          const float new_input = std::tanh(activ_temp_data[Offset(
              activ_temp_shape, b, h, w, 1 * output_depth + c)]);
          const float forget_gate =
              1.f /
              (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
                                                      2 * output_depth + c)]));
          const float output_gate =
              1.f /
              (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
                                                      3 * output_depth + c)]));
          const float new_state =
              input_gate * new_input +
              forget_gate *
                  prev_state_data[Offset(prev_state_shape, b, h, w, c)];
          output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
          output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
              output_gate * std::tanh(new_state);
        }
      }
    }
  }
}

// Quantized LSTM cell implementation.
// The quantization of the input, output arrays is as follows:
//  - The input activations are quantized as uint8 on the interval
//    [-1, 127/128].
//    The rationale for that is that this is the natural interval for output
//    activations (see next point) and these need to be concatenated together.
//    We could accommodate different ranges by re-scaling, but we empirically
//    found that setting the input activations range to be [-1, 127/128] in the
//    first place, removing the need for re-scaling, greatly improves accuracy.
//  - The output activations are quantized as uint8 on the interval
//    [-1, 127/128].
//    The rationale for that is that the definition of a LSTM cell makes them
//    intrinsically constrained in [-1, 1]; tweaking that to [-1, 127/128]
//    makes for simpler, more accurate fixed-point arithmetic.
//  - The output-at-previous-timestep state array is quantized in the same way
//    as the output activations.
//  - The internal LSTM memory (not the output-at-previous-timestep, the other
//    internal state array) is int16-quantized and may use any power-of-two,
//    symmetric range i.e. [-2^N, 2^N * 32767/32768] for any N, which we call
//    StateIntegerBits below, see the below discussion of that template
//    parameter ("The StateIntegerBits template parameter").
//  - The output of the internal fully-connected node is int16-quantized
//    on the interval [-8, 8 * 32767/32768], the rationale for which is
//    explained just below ("Why [-8, 8] for fully-connected output?").
//
//
// === The StateIntegerBits template parameter ===
//
// The StateIntegerBits template parameter controls the fixed-point format used
// to represent the internal memory of the LSTM cell (not the
// output-at-previous-timestep, the other internal state array). It's currently
// a template parameter so that the model can control that. The most typical
// value for StateIntegerBits is 4. Other plausible values are anywhere between
// 3 and 5. We might eventually standardize on a single supported value, e.g. 4,
// and drop that template parameter. The reason why it can't be a runtime
// parameter is that this controls the fixed-point format used, i.e. we need to
// generate actually different code based on it. In particular, we generate code
// for a fixed-point tanh() implementation for that format, which internally
// uses a fixed-point exp() implementation, which internally uses a
// barrel-shifter with a number of steps that depends on StateIntegerBits.
// Another consequence of that is that a higher value of StateIntegerBits
// results in a more expensive implementation (more barrel shifter steps
// needed).
//
//
// === Why [-8, 8] for fully-connected output? ===
//
// This array is only fed to Logistic and Tanh functions, for which
// the quantized implementation will want to use fixed-point arithmetic,
// requiring a power-of-two representation interval. Thus, we should right
// away quantize this array to a power-of-two interval; otherwise,
// implementation will need to rescale that, losing any benefit that a tighter
// representation interval might otherwise yield, while introducing some
// numerical error and computational overhead.
//
// Now, Logistic and Tanh
// are nearly constant (nearly equal to their horizontal asymptotes)
// outside of a small bounded interval around 0:
//
//   Logistic(4) = 1 - 1.8e-2     Tanh(4) = 1 - 6.7e-4
//   Logistic(8) = 1 - 3.4e-4     Tanh(8) = 1 - 2.3e-7
//   Logistic(16) = 1 - 1.1e-7    Tanh(16) = 1 - 2.5e-14
//
// From this, we see that clamping to [-4, 4] would be too inaccurate
// (the error of 1.8e-2 on Logistic would be felt even in 8-bit precision)
// while clamping to [-16, 16] would make no difference even in float32.
// However, for a fixed-point implementation in 16-bit integers, using 5
// integer bits to represent the [-16, 16] range would leave only 11
// fractional bits, giving an increment of 2^-11 = 4.9e-4 between consecutive
// representable values. Notice that this is higher than the
// worst-case clamping error with clamping to [-8, 8]: 3.4e-4 for Logistic.
// Using [-8, 8] thus seems like the better compromise overall, enjoying
// an increment of 2.4e-4 between representable values and a worst-case
// clamping error of 3.4e-4, both better than the increment of 4.9e-4 with
// [-16, 16].
//
// Moreover, all other things being equal, it is nice to choose the narrower
// representation range, as that makes the implementation of fixed-point
// math functions a little cheaper (each integer bit requires an additional
// barrel-shifter step in the implementation of exp(-x)). That is further
// reason to prefer [-8, 8] over [-16, 16]. The choice of [-16, 16] would make
// sense for 32-bit float or 32-bit fixed-point quantization, but we are
// aiming for 16-bit fixed-point quantization of these internal nodes here.
//
template <int StateIntegerBits>
inline void LstmCell(const LstmCellParams& params,
                     const RuntimeShape& unextended_input_shape,
                     const uint8* input_data_uint8,
                     const RuntimeShape& unextended_prev_activ_shape,
                     const uint8* prev_activ_data_uint8,
                     const RuntimeShape& weights_shape,
                     const uint8* weights_data_uint8,
                     const RuntimeShape& unextended_bias_shape,
                     const int32* bias_data_int32,
                     const RuntimeShape& unextended_prev_state_shape,
                     const int16* prev_state_data_int16,
                     const RuntimeShape& unextended_output_state_shape,
                     int16* output_state_data_int16,
                     const RuntimeShape& unextended_output_activ_shape,
                     uint8* output_activ_data_uint8,
                     const RuntimeShape& unextended_concat_temp_shape,
                     uint8* concat_temp_data_uint8,
                     const RuntimeShape& unextended_activ_temp_shape,
                     int16* activ_temp_data_int16, void* gemmlowp_context) {
  (void)gemmlowp_context;  // only used in optimized code.
  int32 weights_zero_point = params.weights_zero_point;
  int32 accum_multiplier = params.accum_multiplier;
  int accum_shift = params.accum_shift;
  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
  const RuntimeShape input_shape =
      RuntimeShape::ExtendedShape(4, unextended_input_shape);
  const RuntimeShape prev_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
  const RuntimeShape bias_shape =
      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
  const RuntimeShape prev_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
  const RuntimeShape output_state_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
  const RuntimeShape output_activ_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
  const RuntimeShape concat_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
  const RuntimeShape activ_temp_shape =
      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);

  // Gather dimensions information, and perform consistency checks.
  const int weights_dim_count = weights_shape.DimensionsCount();
  const int outer_size = MatchingFlatSizeSkipDim(
      input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
      output_activ_shape);
  const int input_depth = input_shape.Dims(3);
  const int prev_activ_depth = prev_activ_shape.Dims(3);
  const int total_input_depth = prev_activ_depth + input_depth;
  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
                   total_input_depth);
  const int intern_activ_depth =
      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
                   intern_activ_depth * total_input_depth);
  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
  const int output_depth =
      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
                  3, output_activ_shape, 3);
  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
  const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
  const int fc_output_depth =
      MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
  const int fc_accum_depth = total_input_depth;
  TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);

  // Depth-concatenate prev_activ and input data together.
  uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
                                              prev_activ_data_uint8};
  const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
                                                       &prev_activ_shape};
  tflite::ConcatenationParams concat_params;
  concat_params.axis = 3;
  concat_params.inputs_count = 2;
  Concatenation(concat_params, concat_input_arrays_shapes,
                concat_input_arrays_data, concat_temp_shape,
                concat_temp_data_uint8);

  // Implementation of the fully connected node inside the LSTM cell.
  // The operands are 8-bit integers, the accumulators are internally 32-bit
  // integers, and the output is 16-bit fixed-point with 3 integer bits so
  // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
  // is explained in the function comment above.
  for (int b = 0; b < fc_batches; ++b) {
    for (int out_c = 0; out_c < fc_output_depth; ++out_c) {
      // Internal accumulation.
      // Initialize accumulator with the bias-value.
      int32 accum = bias_data_int32[out_c];
      // Accumulation loop.
      for (int d = 0; d < fc_accum_depth; ++d) {
        int16 input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128;
        int16 weights_val =
            weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point;
        accum += input_val * weights_val;
      }
      // Down-scale the final int32 accumulator to the scale used by our
      // (16-bit, using 3 integer bits) fixed-point format. The quantized
      // multiplier and shift here have been pre-computed offline
      // (e.g. by toco).
      accum =
          MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift);
      // Saturate, cast to int16, and store to the temporary activations array.
      accum = std::max(-32768, std::min(32767, accum));
      activ_temp_data_int16[out_c + fc_output_depth * b] = accum;
    }
  }

  // Rest of the LSTM cell: tanh and logistic math functions, and some adds
  // and muls, all done in 16-bit fixed-point.
  for (int b = 0; b < outer_size; ++b) {
    for (int c = 0; c < output_depth; ++c) {
      // Define the fixed-point data types that we will use here. All use
      // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
      // They only differ by the number of integral vs. fractional bits,
      // determining the range of values that they can represent.
      //
      // F0 uses 0 integer bits, range [-1, 1].
      // This is the return type of math functions such as tanh, logistic,
      // whose range is in [-1, 1].
      using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
      // F3 uses 3 integer bits, range [-8, 8].
      // This is the range of the previous fully-connected node's output,
      // which is our input here.
      using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
      // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
      // 2^StateIntegerBits]. It's used to represent the internal state, whose
      // number of integer bits is currently dictated by the model. See comment
      // on the StateIntegerBits template parameter above.
      using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
      // Implementation of input gate, using fixed-point logistic function.
      F3 input_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]);
      F0 input_gate_output = gemmlowp::logistic(input_gate_input);
      // Implementation of input modulation gate, using fixed-point tanh
      // function.
      F3 input_modulation_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]);
      F0 input_modulation_gate_output =
          gemmlowp::tanh(input_modulation_gate_input);
      // Implementation of forget gate, using fixed-point logistic function.
      F3 forget_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]);
      F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
      // Implementation of output gate, using fixed-point logistic function.
      F3 output_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]);
      F0 output_gate_output = gemmlowp::logistic(output_gate_input);
      // Implementation of internal multiplication nodes, still in fixed-point.
      F0 input_times_input_modulation =
          input_gate_output * input_modulation_gate_output;
      FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]);
      FS prev_state_times_forget_state = forget_gate_output * prev_state;
      // Implementation of internal addition node, saturating.
      FS new_state = gemmlowp::SaturatingAdd(
          gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
          prev_state_times_forget_state);
      // Implementation of last internal Tanh node, still in fixed-point.
      // Since a Tanh fixed-point implementation is specialized for a given
      // number of integer bits, and each specialization can have a substantial
      // code size, and we already used above a Tanh on an input with 3 integer
      // bits, and per the table in the above function comment there is no
      // significant accuracy to be lost by clamping to [-8, +8] for a
      // 3-integer-bits representation, let us just do that. This helps people
      // porting this to targets where code footprint must be minimized.
      F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
      F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
      // Store the new internal state back to memory, as 16-bit integers.
      // Note: here we store the original value with StateIntegerBits, not
      // the rescaled 3-integer-bits value fed to tanh.
      output_state_data_int16[b * output_depth + c] = new_state.raw();
      // Down-scale the output activations to 8-bit integers, saturating,
      // and store back to memory.
      int16 rescaled_output_activ =
          gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
      int16 clamped_output_activ =
          std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
      output_activ_data_uint8[b * output_depth + c] =
          128 + clamped_output_activ;
    }
  }
}

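// Split: partitions the input along `params.axis` (which may be negative) into
// `params.num_split` outputs whose sizes along that axis must sum to the
// input's size along it.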
template <typename Scalar>
void Split(const SplitParams& params, const RuntimeShape& input_shape,
           const Scalar* input_data, const RuntimeShape* const* output_shapes,
           Scalar* const* output_data) {
  ruy::profiler::ScopeLabel label("Split");
  const int split_dimensions = input_shape.DimensionsCount();
  int axis = params.axis < 0 ? params.axis + split_dimensions : params.axis;
  int outputs_count = params.num_split;
  TFLITE_DCHECK_LT(axis, split_dimensions);

  int64_t split_size = 0;
  for (int i = 0; i < outputs_count; i++) {
    TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), split_dimensions);
    for (int j = 0; j < split_dimensions; j++) {
      if (j != axis) {
        MatchingDim(*output_shapes[i], j, input_shape, j);
      }
    }
    split_size += output_shapes[i]->Dims(axis);
  }
  TFLITE_DCHECK_EQ(split_size, input_shape.Dims(axis));
  int64_t outer_size = 1;
  for (int i = 0; i < axis; ++i) {
    outer_size *= input_shape.Dims(i);
  }
  // For all output arrays,
  // FlatSize() = outer_size * Dims(axis) * base_inner_size;
  int64_t base_inner_size = 1;
  for (int i = axis + 1; i < split_dimensions; ++i) {
    base_inner_size *= input_shape.Dims(i);
  }

  const Scalar* input_ptr = input_data;
  for (int k = 0; k < outer_size; k++) {
    for (int i = 0; i < outputs_count; ++i) {
      const int copy_size = output_shapes[i]->Dims(axis) * base_inner_size;
      memcpy(output_data[i] + k * copy_size, input_ptr,
             copy_size * sizeof(Scalar));
      input_ptr += copy_size;
    }
  }
}

inline int NodeOffset(int b, int h, int w, int height, int width) {
  return (b * height + h) * width + w;
}

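// Local response normalization over the depth (last) axis:
//   output[c] = input[c] * (bias + alpha * sum(input[c']^2))^(-beta),
// where c' ranges over [c - range, c + range) clamped to [0, depth).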
inline void LocalResponseNormalization(
    const tflite::LocalResponseNormalizationParams& op_params,
    const RuntimeShape& input_shape, const float* input_data,
    const RuntimeShape& output_shape, float* output_data) {
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int i = 0; i < outer_size; ++i) {
    for (int c = 0; c < depth; ++c) {
      const int begin_input_c = std::max(0, c - op_params.range);
      const int end_input_c = std::min(depth, c + op_params.range);
      float accum = 0.f;
      for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) {
        const float input_val = input_data[i * depth + input_c];
        accum += input_val * input_val;
      }
      const float multiplier =
          std::pow(op_params.bias + op_params.alpha * accum, -op_params.beta);
      output_data[i * depth + c] = input_data[i * depth + c] * multiplier;
    }
  }
}

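// Dequantize overload for float16 input: widens each Eigen::half element to
// float.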
inline void Dequantize(const RuntimeShape& input_shape,
                       const Eigen::half* input_data,
                       const RuntimeShape& output_shape, float* output_data) {
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; i++) {
    output_data[i] = static_cast<float>(input_data[i]);
  }
}

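// FakeQuant: nudges [rmin, rmax] so that zero is exactly representable with
// `num_bits` of precision, then quantizes and immediately dequantizes the
// float input onto that grid.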
inline void FakeQuant(const tflite::FakeQuantParams& op_params,
                      const RuntimeShape& input_shape, const float* input_data,
                      const RuntimeShape& output_shape, float* output_data) {
  ruy::profiler::ScopeLabel label("FakeQuant");
  float rmin = op_params.minmax.min;
  float rmax = op_params.minmax.max;
  int num_bits = op_params.num_bits;
  // 0 should always be a representable value. Let's assume that the initial
  // min,max range contains 0.
  TFLITE_DCHECK_LE(rmin, 0.0f);
  TFLITE_DCHECK_GE(rmax, 0.0f);
  TFLITE_DCHECK_LT(rmin, rmax);

  // Code matches tensorflow's FakeQuantWithMinMaxArgsFunctor.
  int quant_min = 0;
  int quant_max = (1 << num_bits) - 1;
  float nudged_min, nudged_max, nudged_scale;
  NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min,
                         &nudged_max, &nudged_scale);
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data,
                    output_data, flat_size);
}

// Common subroutine for both `GatherNd` and `GatherNdString`.
struct GatherNdHelperResult {
  int n_slices;
  int slice_size;
  int indices_nd;
  std::vector<int> dims_to_count;
};

// Returns common values being used on both `GatherNd` and `GatherNdString`.
inline GatherNdHelperResult GatherNdHelper(const RuntimeShape& params_shape,
                                           const RuntimeShape& indices_shape) {
  GatherNdHelperResult ret;
  ret.n_slices = 1;
  ret.slice_size = 1;
  const int indices_dims = indices_shape.DimensionsCount();
  ret.indices_nd = indices_shape.Dims(indices_dims - 1);
  const int params_dims = params_shape.DimensionsCount();
  for (int i = 0; i < indices_dims - 1; ++i) {
    ret.n_slices *= indices_shape.Dims(i);
  }
  for (int i = ret.indices_nd; i < params_dims; ++i) {
    ret.slice_size *= params_shape.Dims(i);
  }

  int remain_flat_size = params_shape.FlatSize();
  ret.dims_to_count = std::vector<int>(ret.indices_nd, 0);
  for (int i = 0; i < ret.indices_nd; ++i) {
    ret.dims_to_count[i] = remain_flat_size / params_shape.Dims(i);
    remain_flat_size = ret.dims_to_count[i];
  }

  return ret;
}

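// GatherNd: each row of `indices_data` (of length indices_nd) addresses one
// slice of `params_data`; the selected slices are copied contiguously into the
// output.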
template <typename ParamsT, typename IndicesT = int32>
inline void GatherNd(const RuntimeShape& params_shape,
                     const ParamsT* params_data,
                     const RuntimeShape& indices_shape,
                     const IndicesT* indices_data,
                     const RuntimeShape& output_shape, ParamsT* output_data) {
  ruy::profiler::ScopeLabel label("GatherNd");

  const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
  for (int i = 0; i < res.n_slices; ++i) {
    int from_pos = 0;
    for (int j = 0; j < res.indices_nd; ++j) {
      from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
    }
    std::memcpy(output_data + i * res.slice_size, params_data + from_pos,
                sizeof(ParamsT) * res.slice_size);
  }
}

#ifndef TF_LITE_STATIC_MEMORY
template <typename IndicesT = int32>
inline void GatherNdString(const RuntimeShape& params_shape,
                           const TfLiteTensor* params_data,
                           const RuntimeShape& indices_shape,
                           const IndicesT* indices_data,
                           const RuntimeShape& output_shape,
                           TfLiteTensor* output_data) {
  ruy::profiler::ScopeLabel label("GatherNdString");

  const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
  DynamicBuffer buffer;
  for (int i = 0; i < res.n_slices; ++i) {
    int from_pos = 0;
    for (int j = 0; j < res.indices_nd; ++j) {
      from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
    }
    for (int j = 0; j < res.slice_size; ++j) {
      buffer.AddString(GetString(params_data, from_pos + j));
    }
  }
  buffer.WriteToTensor(output_data, /*new_shape=*/nullptr);
}
#endif

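// ScatterNd: zero-initializes the output, then accumulates (+=) each slice of
// `updates_data` at the position addressed by the corresponding row of
// `indices_data`, so duplicate indices add up.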
template <typename IndicesT, typename UpdatesT>
inline void ScatterNd(const RuntimeShape& indices_shape,
                      const IndicesT* indices_data,
                      const RuntimeShape& updates_shape,
                      const UpdatesT* updates_data,
                      const RuntimeShape& output_shape, UpdatesT* output_data) {
  ruy::profiler::ScopeLabel label("ScatterNd");

  int n_slices = 1;
  int slice_size = 1;
  const int outer_dims = indices_shape.DimensionsCount() - 1;
  const int indices_nd = indices_shape.Dims(outer_dims);
  const int updates_dims = updates_shape.DimensionsCount();
  for (int i = 0; i < outer_dims; ++i) {
    n_slices *= indices_shape.Dims(i);
  }
  for (int i = outer_dims; i < updates_dims; ++i) {
    slice_size *= updates_shape.Dims(i);
  }

  int output_flat_size = output_shape.FlatSize();
  int remain_flat_size = output_flat_size;
  std::vector<int> dims_to_count(indices_nd, 0);
  for (int i = 0; i < indices_nd; ++i) {
    dims_to_count[i] = remain_flat_size / output_shape.Dims(i);
    remain_flat_size = dims_to_count[i];
  }

  memset(output_data, 0, sizeof(UpdatesT) * output_flat_size);
  for (int i = 0; i < n_slices; ++i) {
    int to_pos = 0;
    for (int j = 0; j < indices_nd; ++j) {
      IndicesT idx = indices_data[i * indices_nd + j];
      TFLITE_DCHECK(0 <= idx && idx < output_shape.Dims(j));
      to_pos += idx * dims_to_count[j];
    }
    for (int j = 0; j < slice_size; j++) {
      output_data[to_pos + j] += updates_data[i * slice_size + j];
    }
  }
}

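// Slice over up to 5 dimensions. The begin/size vectors are front-padded to
// length 5, and a size of -1 means "to the end of that dimension". The writer
// abstraction lets the same loops serve both raw buffers and TfLiteTensors.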
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
                  const RuntimeShape& input_shape,
                  const RuntimeShape& output_shape,
                  SequentialTensorWriter<T>* writer) {
  const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(5, input_shape);
  TFLITE_DCHECK_LE(op_params.begin_count, 5);
  TFLITE_DCHECK_LE(op_params.size_count, 5);
  const int begin_count = op_params.begin_count;
  const int size_count = op_params.size_count;
  // We front-pad the begin and size vectors.
  std::array<int, 5> start;
  std::array<int, 5> stop;
  for (int i = 0; i < 5; ++i) {
    int padded_i = 5 - i;
    start[i] =
        begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
    stop[i] =
        (size_count < padded_i || op_params.size[size_count - padded_i] == -1)
            ? ext_shape.Dims(i)
            : start[i] + op_params.size[size_count - padded_i];
  }

  for (int i0 = start[0]; i0 < stop[0]; ++i0) {
    for (int i1 = start[1]; i1 < stop[1]; ++i1) {
      for (int i2 = start[2]; i2 < stop[2]; ++i2) {
        for (int i3 = start[3]; i3 < stop[3]; ++i3) {
          for (int i4 = start[4]; i4 < stop[4]; ++i4) {
            writer->Write(Offset(ext_shape, i0, i1, i2, i3, i4));
          }
        }
      }
    }
  }
}

template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
                  const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
  SequentialTensorWriter<T> writer(input_data, output_data);
  return Slice(op_params, input_shape, output_shape, &writer);
}

template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
                  const RuntimeShape& input_shape, const TfLiteTensor* input,
                  const RuntimeShape& output_shape, TfLiteTensor* output) {
  SequentialTensorWriter<T> writer(input, output);
  return Slice(op_params, input_shape, output_shape, &writer);
}

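// Minimum/Maximum below broadcast a scalar second operand: every element of
// input1 is compared against input2_data[0].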
template <typename T>
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
             const T* input2_data, const RuntimeShape& output_shape,
             T* output_data) {
  const int flat_size = MatchingFlatSize(input1_shape, output_shape);

  auto min_value = input2_data[0];
  for (int i = 0; i < flat_size; i++) {
    output_data[i] = input1_data[i] > min_value ? min_value : input1_data[i];
  }
}

// Convenience version that allows, for example, generated-code calls to be
// the same as other binary ops.
template <typename T>
inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
                    const RuntimeShape&, const T* input2_data,
                    const RuntimeShape& output_shape, T* output_data) {
  // Drop shape of second input: not needed.
  Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
}

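// Elementwise maximum against a scalar second operand: only the first element
// of input2_data is read, and output[i] = max(input1[i], input2[0]).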
template <typename T>
void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
             const T* input2_data, const RuntimeShape& output_shape,
             T* output_data) {
  const int flat_size = MatchingFlatSize(input1_shape, output_shape);

  auto max_value = input2_data[0];
  for (int i = 0; i < flat_size; i++) {
    output_data[i] = input1_data[i] < max_value ? max_value : input1_data[i];
  }
}

// Convenience version that allows, for example, generated-code calls to be
// the same as other binary ops.
template <typename T>
inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
                    const RuntimeShape&, const T* input2_data,
                    const RuntimeShape& output_shape, T* output_data) {
  // Drop shape of second input: not needed.
  Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
}

template <typename T1, typename T2, typename T3>
void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
            const T3* input2_data, const RuntimeShape& output_shape,
            T2* output_data) {
  ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
            std::greater<T1>());
}

// Convenience version that allows, for example, generated-code calls to be
// the same as other binary ops.
template <typename T1, typename T2, typename T3>
inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
                   const RuntimeShape& input2_shape, const T3* input2_data,
                   const RuntimeShape& output_shape, T2* output_data) {
  // Drop shape of second input: not needed.
  ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
}

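// Elementwise select: output[i] = condition[i] ? x[i] : y[i]. The flattened
// sizes of condition, x, y, and output must all match (or all be 1).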
template <typename D, typename T>
void Select(const RuntimeShape& input_condition_shape,
            const D* input_condition_data, const RuntimeShape& input_x_shape,
            const T* input_x_data, const RuntimeShape& input_y_shape,
            const T* input_y_data, const RuntimeShape& output_shape,
            T* output_data) {
  int64_t flatsize;
  // Allow the select operator to run on a mix of scalar tensors and
  // one-element tensors.
  if (input_condition_shape.FlatSize() == 1 && input_x_shape.FlatSize() == 1 &&
      input_y_shape.FlatSize() == 1 && output_shape.FlatSize() == 1) {
    flatsize = 1;
  } else {
    flatsize = MatchingFlatSize(input_condition_shape, input_x_shape,
                                input_y_shape, output_shape);
  }
  for (int64_t i = 0; i < flatsize; ++i) {
    output_data[i] =
        input_condition_data[i] ? input_x_data[i] : input_y_data[i];
  }
}

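// Select where the condition is a scalar or a rank-1 tensor along the
// outermost dimension: each condition value picks a whole inner slice from
// either x or y, which is copied to the output with memcpy.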
template <typename D, typename T>
void RankOneSelect(const RuntimeShape& input_condition_shape,
                   const D* input_condition_data,
                   const RuntimeShape& input_x_shape, const T* input_x_data,
                   const RuntimeShape& input_y_shape, const T* input_y_data,
                   const RuntimeShape& output_shape, T* output_data) {
  const int64_t outer_size = input_condition_shape.FlatSize();
  int64_t inner_size;
  if (input_condition_shape.DimensionsCount() == 0) {
    inner_size = MatchingFlatSize(input_x_shape, input_y_shape, output_shape);
  } else {
    TFLITE_DCHECK_EQ(
        MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0),
        outer_size);
    inner_size =
        MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
  }

  int64_t offset = 0;
  for (int64_t i = 0; i < outer_size; i++) {
    const T* input_data = input_condition_data[i] ? input_x_data : input_y_data;
    memcpy(output_data + offset, input_data + offset, inner_size * sizeof(T));
    offset += inner_size;
  }
}

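// Select with elementwise broadcasting of condition, x, and y, each of rank
// at most 4.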
template <typename D, typename T>
void BroadcastSelect4DSlow(const RuntimeShape& input_condition_shape,
                           const D* input_condition_data,
                           const RuntimeShape& input_x_shape,
                           const T* input_x_data,
                           const RuntimeShape& input_y_shape,
                           const T* input_y_data,
                           const RuntimeShape& output_shape, T* output_data) {
  TFLITE_DCHECK_LE(input_condition_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(input_x_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(input_y_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 4);

  const RuntimeShape extended_output_shape =
      RuntimeShape::ExtendedShape(4, output_shape);

  NdArrayDesc<4> desc_condition;
  NdArrayDesc<4> desc_x;
  NdArrayDesc<4> desc_y;
  NdArrayDescsForElementwiseBroadcast(input_condition_shape, input_x_shape,
                                      input_y_shape, &desc_condition, &desc_x,
                                      &desc_y);

  // In TensorFlow, the dimensions are canonically named (batch_number, row,
  // col, channel), with extents (batches, height, width, depth), with the
  // trailing dimension changing most rapidly (channels has the smallest
  // stride, typically 1 element).
  //
  // In generated C code, we store arrays with the dimensions reversed: the
  // first dimension has the smallest stride.
  //
  // We name our variables by the TensorFlow convention, but generate C code
  // nesting loops such that the innermost loop has the smallest stride, for
  // the best cache behavior.
  for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
    for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
      for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
        for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
          const int condition_index =
              SubscriptToIndex(desc_condition, b, y, x, c);
          const int x_index = SubscriptToIndex(desc_x, b, y, x, c);
          const int y_index = SubscriptToIndex(desc_y, b, y, x, c);
          output_data[Offset(extended_output_shape, b, y, x, c)] =
              input_condition_data[condition_index] ? input_x_data[x_index]
                                                    : input_y_data[y_index];
        }
      }
    }
  }
}

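// Writes the row-major coordinates of every true element of the condition
// tensor into output_data, one cond_rank-tuple per true element; the caller
// is responsible for sizing output_data accordingly. For example
// (illustrative), a 2x2 condition {{0, 1}, {1, 0}} produces {0, 1, 1, 0},
// i.e. the coordinates (0, 1) and (1, 0).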
template <typename D, typename T>
void SelectTrueCoords(const RuntimeShape& input_condition_shape,
                      const D* input_condition_data, T* output_data) {
  const size_t size = input_condition_shape.FlatSize();
  if (size == 0) {
    // The condition tensor is empty, so there is nothing to output.
    return;
  }
  const size_t cond_rank = input_condition_shape.DimensionsCount();

  std::vector<int> dims_to_count(cond_rank, 0);
  int cur_flat_size = size;
  for (int i = 0; i < cond_rank; ++i) {
    dims_to_count[i] = cur_flat_size / input_condition_shape.Dims(i);
    cur_flat_size = dims_to_count[i];
  }

  int output_index = 0;
  for (int i = 0; i < size; ++i) {
    if (input_condition_data[i]) {
      // Insert the coordinates of the current item (row major) into the
      // output.
      int flat_index = i;
      for (int j = 0; j < cond_rank; ++j) {
        int coord_j = flat_index / dims_to_count[j];
        output_data[output_index * cond_rank + j] = coord_j;
        flat_index %= dims_to_count[j];
      }
      output_index++;
    }
  }
}

// For ease of implementation, `indices` is always a vector of size-4 vectors.
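// Scatters `values` into an output tensor that is first filled with
// default_value. For example (illustrative), an index {0, 0, 1, 2} places its
// value at Offset(output_shape, 0, 0, 1, 2) in the dense output.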
template <typename T, typename TI>
inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
                          const T* values, T default_value,
                          bool value_is_scalar,
                          const RuntimeShape& unextended_output_shape,
                          T* output_data) {
  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
  const RuntimeShape output_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_shape);
  const int value_count = indices.size();

  // First, fill output_data with the default value.
  const int num_elements = output_shape.FlatSize();
  for (int i = 0; i < num_elements; ++i) {
    output_data[i] = default_value;
  }

  // Special-case a scalar value to avoid re-checking the boolean condition on
  // every loop iteration.
  if (value_is_scalar) {
    for (int i = 0; i < value_count; ++i) {
      const std::vector<TI>& index = indices[i];
      TFLITE_DCHECK_EQ(index.size(), 4);
      const T value = *values;  // just use the first value.
      output_data[Offset(output_shape, index[0], index[1], index[2],
                         index[3])] = value;
    }
    return;
  }

  // Go through the values and indices to fill in the sparse values.
  for (int i = 0; i < value_count; ++i) {
    const std::vector<TI>& index = indices[i];
    TFLITE_DCHECK_EQ(index.size(), 4);
    const T value = values[i];
    output_data[Offset(output_shape, index[0], index[1], index[2], index[3])] =
        value;
  }
}

template <typename T>
inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
                const RuntimeShape& input2_shape, const T* input2_data,
                const RuntimeShape& output_shape, T* output_data) {
  const int flat_size =
      MatchingFlatSize(input1_shape, input2_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    output_data[i] = std::pow(input1_data[i], input2_data[i]);
  }
}

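// Elementwise power with broadcasting of the two inputs, each of rank at
// most 4.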
template <typename T>
inline void BroadcastPow4DSlow(const RuntimeShape& unextended_input1_shape,
                               const T* input1_data,
                               const RuntimeShape& unextended_input2_shape,
                               const T* input2_data,
                               const RuntimeShape& unextended_output_shape,
                               T* output_data) {
  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
  const RuntimeShape output_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_shape);

  NdArrayDesc<4> desc1;
  NdArrayDesc<4> desc2;
  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
                                      unextended_input2_shape, &desc1, &desc2);

  for (int b = 0; b < output_shape.Dims(0); ++b) {
    for (int y = 0; y < output_shape.Dims(1); ++y) {
      for (int x = 0; x < output_shape.Dims(2); ++x) {
        for (int c = 0; c < output_shape.Dims(3); ++c) {
          auto out_idx = Offset(output_shape, b, y, x, c);
          auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
          auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
          auto in1_val = input1_data[in1_idx];
          auto in2_val = input2_data[in2_idx];
          output_data[out_idx] = std::pow(in1_val, in2_val);
        }
      }
    }
  }
}

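// Reverses the input along `axis`: within each outer slice, the contiguous
// inner blocks below `axis` are copied to the output in mirrored order.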
template <typename Scalar>
void Reverse(int axis, const RuntimeShape& input_shape,
             const Scalar* input_data, const RuntimeShape& output_shape,
             Scalar* output_data) {
  ruy::profiler::ScopeLabel label("Reverse");

  int outer_size = 1;
  for (int i = 0; i < axis; ++i) {
    outer_size *= input_shape.Dims(i);
  }

  int copy_size = 1;
  for (int i = axis + 1; i < input_shape.DimensionsCount(); ++i) {
    copy_size *= input_shape.Dims(i);
  }

  const int dims_at_axis = input_shape.Dims(axis);
  for (int i = 0; i < outer_size; ++i) {
    for (int j = 0; j < dims_at_axis; ++j) {
      const int start_pos = (i * dims_at_axis + j) * copy_size;
      Scalar* output_ptr = output_data + start_pos;
      int loc = (i * dims_at_axis + dims_at_axis - j - 1) * copy_size;
      memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
    }
  }
}

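// Reverses each sequence along seq_dim: for batch index b along batch_dim,
// the first seq_lengths[b] entries are reversed and any remaining entries are
// copied through unchanged.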
template <typename Scalar, typename TS>
void ReverseSequence(const TS* seq_lengths, const int seq_dim,
                     const int batch_dim, const RuntimeShape& input_shape,
                     const Scalar* input_data, const RuntimeShape& output_shape,
                     Scalar* output_data) {
  ruy::profiler::ScopeLabel label("ReverseSequence");

  int outer_size = 1;
  int outer_dim = std::min(batch_dim, seq_dim);
  int medium_dim = std::max(batch_dim, seq_dim);
  for (int i = 0; i < outer_dim; ++i) {
    outer_size *= input_shape.Dims(i);
  }

  int medium_size = 1;
  for (int i = outer_dim + 1; i < medium_dim; ++i) {
    medium_size *= input_shape.Dims(i);
  }

  int copy_size = 1;
  for (int i = medium_dim + 1; i < input_shape.DimensionsCount(); ++i) {
    copy_size *= input_shape.Dims(i);
  }

  const int dims_at_outer_dim = input_shape.Dims(outer_dim);
  const int dims_at_medium_dim = input_shape.Dims(medium_dim);

  Scalar* output_ptr;
  if (batch_dim > seq_dim) {
    for (int i = 0; i < outer_size; ++i) {
      for (int j = 0; j < dims_at_outer_dim; ++j) {
        const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size;
        for (int p = 0; p < medium_size; ++p) {
          for (int q = 0; q < dims_at_medium_dim; ++q) {
            const int in_pos =
                ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size;
            const Scalar* in_ptr = input_data + in_pos;
            int sl = seq_lengths[q] - 1;
            if (j > sl) {
              output_ptr = output_data + in_pos;
            } else {
              const int out_pos_base =
                  (i * dims_at_outer_dim + sl - j) * medium_size;
              const int out_pos =
                  ((out_pos_base + p) * dims_at_medium_dim + q) * copy_size;
              output_ptr = output_data + out_pos;
            }
            memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar));
          }
        }
      }
    }
  } else if (batch_dim < seq_dim) {
    for (int i = 0; i < outer_size; ++i) {
      for (int j = 0; j < dims_at_outer_dim; ++j) {
        const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size;
        int sl = seq_lengths[j] - 1;
        const int out_pos_base = (i * dims_at_outer_dim + j) * medium_size;
        for (int p = 0; p < medium_size; ++p) {
          for (int q = 0; q < dims_at_medium_dim; ++q) {
            const int in_pos =
                ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size;
            const Scalar* in_ptr = input_data + in_pos;
            if (q > sl) {
              output_ptr = output_data + in_pos;
            } else {
              const int out_pos =
                  ((out_pos_base + p) * dims_at_medium_dim + sl - q) *
                  copy_size;
              output_ptr = output_data + out_pos;
            }
            memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar));
          }
        }
      }
    }
  }
}

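// Sums slices of the input along its first dimension into output slices
// selected by segment_ids_data: output[segment_ids[i]] += input[i] for each
// row i. The output is zero-initialized first.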
template <typename T>
inline void SegmentSum(const RuntimeShape& input_shape, const T* input_data,
                       const RuntimeShape& segment_ids_shape,
                       const int32_t* segment_ids_data,
                       const RuntimeShape& output_shape, T* output_data) {
  const int segment_flat_size =
      MatchingFlatSizeSkipDim(input_shape, 0, output_shape);

  memset(output_data, 0, sizeof(T) * output_shape.FlatSize());

  for (int i = 0; i < input_shape.Dims(0); i++) {
    int output_index = segment_ids_data[i];
    for (int j = 0; j < segment_flat_size; ++j) {
      output_data[output_index * segment_flat_size + j] +=
          input_data[i * segment_flat_size + j];
    }
  }
}

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_