1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
17 
18 #include <stdint.h>
19 #include <sys/types.h>
20 
21 #include <algorithm>
22 #include <array>
23 #include <cmath>
24 #include <cstring>
25 #include <functional>
26 #include <limits>
27 #include <memory>
28 #include <type_traits>
29 
30 #include "Eigen/Core"
31 #include "fixedpoint/fixedpoint.h"
32 #include "ruy/profiler/instrumentation.h"  // from @ruy
33 #include "tensorflow/lite/c/common.h"
34 #include "tensorflow/lite/kernels/internal/common.h"
35 #include "tensorflow/lite/kernels/internal/quantization_util.h"
36 #include "tensorflow/lite/kernels/internal/reference/add.h"
37 #include "tensorflow/lite/kernels/internal/reference/add_n.h"
38 #include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
39 #include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"
40 #include "tensorflow/lite/kernels/internal/reference/batch_to_space_nd.h"
41 #include "tensorflow/lite/kernels/internal/reference/binary_function.h"
42 #include "tensorflow/lite/kernels/internal/reference/cast.h"
43 #include "tensorflow/lite/kernels/internal/reference/ceil.h"
44 #include "tensorflow/lite/kernels/internal/reference/comparisons.h"
45 #include "tensorflow/lite/kernels/internal/reference/concatenation.h"
46 #include "tensorflow/lite/kernels/internal/reference/conv.h"
47 #include "tensorflow/lite/kernels/internal/reference/depth_to_space.h"
48 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
49 #include "tensorflow/lite/kernels/internal/reference/div.h"
50 #include "tensorflow/lite/kernels/internal/reference/elu.h"
51 #include "tensorflow/lite/kernels/internal/reference/exp.h"
52 #include "tensorflow/lite/kernels/internal/reference/fill.h"
53 #include "tensorflow/lite/kernels/internal/reference/floor.h"
54 #include "tensorflow/lite/kernels/internal/reference/floor_div.h"
55 #include "tensorflow/lite/kernels/internal/reference/floor_mod.h"
56 #include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
57 #include "tensorflow/lite/kernels/internal/reference/gather.h"
58 #include "tensorflow/lite/kernels/internal/reference/hard_swish.h"
59 #include "tensorflow/lite/kernels/internal/reference/l2normalization.h"
60 #include "tensorflow/lite/kernels/internal/reference/leaky_relu.h"
61 #include "tensorflow/lite/kernels/internal/reference/log_softmax.h"
62 #include "tensorflow/lite/kernels/internal/reference/logistic.h"
63 #include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
64 #include "tensorflow/lite/kernels/internal/reference/mul.h"
65 #include "tensorflow/lite/kernels/internal/reference/neg.h"
66 #include "tensorflow/lite/kernels/internal/reference/pad.h"
67 #include "tensorflow/lite/kernels/internal/reference/pooling.h"
68 #include "tensorflow/lite/kernels/internal/reference/prelu.h"
69 #include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
70 #include "tensorflow/lite/kernels/internal/reference/quantize.h"
71 #include "tensorflow/lite/kernels/internal/reference/reduce.h"
72 #include "tensorflow/lite/kernels/internal/reference/requantize.h"
73 #include "tensorflow/lite/kernels/internal/reference/resize_bilinear.h"
74 #include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
75 #include "tensorflow/lite/kernels/internal/reference/round.h"
76 #include "tensorflow/lite/kernels/internal/reference/softmax.h"
77 #include "tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h"
78 #include "tensorflow/lite/kernels/internal/reference/space_to_depth.h"
79 #include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
80 #include "tensorflow/lite/kernels/internal/reference/string_comparisons.h"
81 #include "tensorflow/lite/kernels/internal/reference/sub.h"
82 #include "tensorflow/lite/kernels/internal/reference/tanh.h"
83 #include "tensorflow/lite/kernels/internal/reference/transpose.h"
84 #include "tensorflow/lite/kernels/internal/reference/transpose_conv.h"
85 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
86 #include "tensorflow/lite/kernels/internal/tensor.h"
87 #include "tensorflow/lite/kernels/internal/types.h"
88 namespace tflite {
89 
90 namespace reference_ops {
91 
92 template <typename T>
93 inline void Relu(const RuntimeShape& input_shape, const T* input_data,
94                  const RuntimeShape& output_shape, T* output_data) {
95   const int flat_size = MatchingFlatSize(input_shape, output_shape);
96   for (int i = 0; i < flat_size; ++i) {
97     const T val = input_data[i];
98     const T lower = 0;
99     const T clamped = val < lower ? lower : val;
100     output_data[i] = clamped;
101   }
102 }
103 
104 template <typename T>
105 inline void Relu1(const RuntimeShape& input_shape, const T* input_data,
106                   const RuntimeShape& output_shape, T* output_data) {
107   ruy::profiler::ScopeLabel label("Relu1 (not fused)");
108   const int flat_size = MatchingFlatSize(input_shape, output_shape);
109   for (int i = 0; i < flat_size; ++i) {
110     const T val = input_data[i];
111     const T upper = 1;
112     const T lower = -1;
113     const T clamped = val > upper ? upper : val < lower ? lower : val;
114     output_data[i] = clamped;
115   }
116 }
117 
118 inline void Relu6(const RuntimeShape& input_shape, const float* input_data,
119                   const RuntimeShape& output_shape, float* output_data) {
120   ruy::profiler::ScopeLabel label("Relu6 (not fused)");
121   const int flat_size = MatchingFlatSize(input_shape, output_shape);
122   for (int i = 0; i < flat_size; ++i) {
123     const float val = input_data[i];
124     const float upper = 6;
125     const float lower = 0;
126     const float clamped = val > upper ? upper : val < lower ? lower : val;
127     output_data[i] = clamped;
128   }
129 }
130 
131 template <typename T>
132 inline void ReluX(const tflite::ReluParams& params,
133                   const RuntimeShape& input_shape, const T* input_data,
134                   const RuntimeShape& output_shape, T* output_data) {
135   ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)");
136   const int flat_size = MatchingFlatSize(input_shape, output_shape);
137   for (int i = 0; i < flat_size; ++i) {
138     const int32 val = static_cast<int32_t>(input_data[i]);
139     int32 clamped = params.output_offset +
140                     MultiplyByQuantizedMultiplier(val - params.input_offset,
141                                                   params.output_multiplier,
142                                                   params.output_shift);
143     clamped = std::max(params.quantized_activation_min, clamped);
144     clamped = std::min(params.quantized_activation_max, clamped);
145     output_data[i] = static_cast<T>(clamped);
146   }
147 }
148 
149 template <typename T>
150 inline void ReluX(const tflite::ActivationParams& params,
151                   const RuntimeShape& input_shape, const T* input_data,
152                   const RuntimeShape& output_shape, T* output_data) {
153   ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)");
154   const int flat_size = MatchingFlatSize(input_shape, output_shape);
155   const T max_value = params.quantized_activation_max;
156   const T min_value = params.quantized_activation_min;
157   for (int i = 0; i < flat_size; ++i) {
158     const T val = input_data[i];
159     const T clamped = val > max_value   ? max_value
160                       : val < min_value ? min_value
161                                         : val;
162     output_data[i] = clamped;
163   }
164 }
165 
166 // TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
167 // dimensionality if the runtime code does a single loop over one dimension
168 // that handles broadcasting as the base case. The code generator would then
169 // generate max(D1, D2) nested for loops.
170 inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
171                                  const RuntimeShape& unswitched_input1_shape,
172                                  const uint8* unswitched_input1_data,
173                                  const RuntimeShape& unswitched_input2_shape,
174                                  const uint8* unswitched_input2_data,
175                                  const RuntimeShape& output_shape,
176                                  uint8* output_data) {
177   ArithmeticParams switched_params = unswitched_params;
178   switched_params.input1_offset = unswitched_params.input2_offset;
179   switched_params.input2_offset = unswitched_params.input1_offset;
180 
181   const bool use_unswitched =
182       unswitched_params.broadcast_category ==
183       tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
184 
185   const ArithmeticParams& params =
186       use_unswitched ? unswitched_params : switched_params;
187   const uint8* input1_data =
188       use_unswitched ? unswitched_input1_data : unswitched_input2_data;
189   const uint8* input2_data =
190       use_unswitched ? unswitched_input2_data : unswitched_input1_data;
191 
192   // Fivefold nested loops. The second input resets its position for each
193   // iteration of the second loop. The first input resets its position at the
194   // beginning of the fourth loop. The innermost loop is an elementwise Mul of
195   // sections of the arrays.
196   uint8* output_data_ptr = output_data;
197   const uint8* input1_data_ptr = input1_data;
198   const uint8* input2_data_reset = input2_data;
199   int y0 = params.broadcast_shape[0];
200   int y1 = params.broadcast_shape[1];
201   int y2 = params.broadcast_shape[2];
202   int y3 = params.broadcast_shape[3];
203   int y4 = params.broadcast_shape[4];
204   for (int i0 = 0; i0 < y0; ++i0) {
205     const uint8* input2_data_ptr;
206     for (int i1 = 0; i1 < y1; ++i1) {
207       input2_data_ptr = input2_data_reset;
208       for (int i2 = 0; i2 < y2; ++i2) {
209         for (int i3 = 0; i3 < y3; ++i3) {
210           MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
211                          output_data_ptr);
212           input2_data_ptr += y4;
213           output_data_ptr += y4;
214         }
215         input1_data_ptr += y4;
216       }
217     }
218     input2_data_reset = input2_data_ptr;
219   }
220 }
221 
222 inline void Mul(const ArithmeticParams& params,
223                 const RuntimeShape& input1_shape, const int16* input1_data,
224                 const RuntimeShape& input2_shape, const int16* input2_data,
225                 const RuntimeShape& output_shape, int16* output_data) {
226   ruy::profiler::ScopeLabel label("Mul/Int16");
227 
228   const int flat_size =
229       MatchingElementsSize(input1_shape, input2_shape, output_shape);
230 
231   for (int i = 0; i < flat_size; i++) {
232     // F0 uses 0 integer bits, range [-1, 1].
233     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
234 
235     F0 unclamped_result =
236         F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
237     output_data[i] = unclamped_result.raw();
238   }
239 }
240 
241 inline void Mul(const ArithmeticParams& params,
242                 const RuntimeShape& input1_shape, const int16* input1_data,
243                 const RuntimeShape& input2_shape, const int16* input2_data,
244                 const RuntimeShape& output_shape, uint8* output_data) {
245   ruy::profiler::ScopeLabel label("Mul/Int16Uint8");
246   int32 output_offset = params.output_offset;
247   int32 output_activation_min = params.quantized_activation_min;
248   int32 output_activation_max = params.quantized_activation_max;
249   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
250 
251   const int flat_size =
252       MatchingElementsSize(input1_shape, input2_shape, output_shape);
253 
254   for (int i = 0; i < flat_size; i++) {
255     // F0 uses 0 integer bits, range [-1, 1].
256     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
257 
258     F0 unclamped_result =
259         F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
260     int16 rescaled_result =
261         gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
262     int16 clamped_result =
263         std::min<int16>(output_activation_max - output_offset, rescaled_result);
264     clamped_result =
265         std::max<int16>(output_activation_min - output_offset, clamped_result);
266     output_data[i] = output_offset + clamped_result;
267   }
268 }
269 
270 inline void Sub16(const ArithmeticParams& params,
271                   const RuntimeShape& input1_shape, const int16_t* input1_data,
272                   const RuntimeShape& input2_shape, const int16_t* input2_data,
273                   const RuntimeShape& output_shape, int16_t* output_data) {
274   ruy::profiler::ScopeLabel label("Sub/Int16");
275   const int input1_shift = params.input1_shift;
276   const int flat_size =
277       MatchingElementsSize(input1_shape, input2_shape, output_shape);
278   const int16 output_activation_min = params.quantized_activation_min;
279   const int16 output_activation_max = params.quantized_activation_max;
280 
281   TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
282   TFLITE_DCHECK_LE(input1_shift, 0);
283   TFLITE_DCHECK_LE(params.input2_shift, 0);
284   const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
285   const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
286   const int input_right_shift =
287       input1_shift == 0 ? -params.input2_shift : -input1_shift;
288 
289   if (input1_shift == 0) {
290     // F0 uses 0 integer bits, range [-1, 1].
291     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
292     for (int i = 0; i < flat_size; ++i) {
293       F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
294       F0 scaled_input = F0::FromRaw(
295           gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
296       F0 result = SaturatingSub(input_ready_scaled, scaled_input);
297       const int16 raw_output = result.raw();
298       const int16 clamped_output = std::min(
299           output_activation_max, std::max(output_activation_min, raw_output));
300       output_data[i] = clamped_output;
301     }
302   } else {
303     // F0 uses 0 integer bits, range [-1, 1].
304     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
305     for (int i = 0; i < flat_size; ++i) {
306       F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
307       F0 scaled_input = F0::FromRaw(
308           gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
309       F0 result = SaturatingSub(scaled_input, input_ready_scaled);
310       const int16 raw_output = result.raw();
311       const int16 clamped_output = std::min(
312           output_activation_max, std::max(output_activation_min, raw_output));
313       output_data[i] = clamped_output;
314     }
315   }
316 }
317 
318 template <typename Scalar>
319 void Pack(const PackParams& params, const RuntimeShape* const* input_shapes,
320           const Scalar* const* input_data, const RuntimeShape& output_shape,
321           Scalar* output_data) {
322   ruy::profiler::ScopeLabel label("Pack");
323   const int dimensions = output_shape.DimensionsCount();
324   int axis = params.axis;
325   int inputs_count = params.inputs_count;
326 
327   int outer_size = 1;
328   for (int i = 0; i < axis; i++) {
329     outer_size *= output_shape.Dims(i);
330   }
331   int copy_size = 1;
332   for (int i = params.axis + 1; i < dimensions; i++) {
333     copy_size *= output_shape.Dims(i);
334   }
335   TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);
336 
337   for (int i = 0; i < inputs_count; ++i) {
338     for (int k = 0; k < outer_size; k++) {
339       const Scalar* input_ptr = input_data[i] + copy_size * k;
340       int loc = k * inputs_count * copy_size + i * copy_size;
341       memcpy(output_data + loc, input_ptr, copy_size * sizeof(Scalar));
342     }
343   }
344 }
345 
346 template <typename Scalar>
347 void Unpack(const UnpackParams& params, const RuntimeShape& input_shape,
348             const Scalar* input_data, const RuntimeShape& output_shape,
349             Scalar* const* output_datas) {
350   ruy::profiler::ScopeLabel label("Unpack");
351   const int dimensions = input_shape.DimensionsCount();
352   const int outputs_count = params.num_split;
353 
354   int outer_size = 1;
355   int axis = params.axis;
356   if (axis < 0) {
357     axis += dimensions;
358   }
359   TFLITE_DCHECK_GE(axis, 0);
360   TFLITE_DCHECK_LT(axis, dimensions);
361   for (int i = 0; i < axis; ++i) {
362     outer_size *= input_shape.Dims(i);
363   }
364   int copy_size = 1;
365   for (int i = axis + 1; i < dimensions; ++i) {
366     copy_size *= input_shape.Dims(i);
367   }
368   TFLITE_DCHECK_EQ(output_shape.FlatSize(), copy_size * outer_size);
369 
370   for (int i = 0; i < outputs_count; ++i) {
371     for (int k = 0; k < outer_size; k++) {
372       Scalar* output_ptr = output_datas[i] + copy_size * k;
373       int loc = k * outputs_count * copy_size + i * copy_size;
374       memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
375     }
376   }
377 }
378 
379 template <typename Scalar>
380 void PackWithScaling(const PackParams& params,
381                      const RuntimeShape* const* input_shapes,
382                      const uint8* const* input_data,
383                      const RuntimeShape& output_shape, uint8* output_data) {
384   ruy::profiler::ScopeLabel label("PackWithScaling");
385   const int dimensions = output_shape.DimensionsCount();
386   int axis = params.axis;
387   const int32* input_zeropoint = params.input_zeropoint;
388   const float* input_scale = params.input_scale;
389   int inputs_count = params.inputs_count;
390   const int32 output_zeropoint = params.output_zeropoint;
391   const float output_scale = params.output_scale;
392 
393   int outer_size = 1;
394   for (int i = 0; i < axis; i++) {
395     outer_size *= output_shape.Dims(i);
396   }
397   int copy_size = 1;
398   for (int i = axis + 1; i < dimensions; i++) {
399     copy_size *= output_shape.Dims(i);
400   }
401   TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);
402 
403   Scalar* output_ptr = output_data;
404   const float inverse_output_scale = 1.f / output_scale;
405   for (int k = 0; k < outer_size; k++) {
406     for (int i = 0; i < inputs_count; ++i) {
407       if (input_zeropoint[i] == output_zeropoint &&
408           input_scale[i] == output_scale) {
409         memcpy(output_ptr, input_data[i] + k * copy_size,
410                copy_size * sizeof(Scalar));
411       } else {
412         assert(false);
413         const float scale = input_scale[i] * inverse_output_scale;
414         const float bias = -input_zeropoint[i] * scale;
415         auto input_ptr = input_data[i];
416         for (int j = 0; j < copy_size; ++j) {
417           const int32_t value =
418               static_cast<int32_t>(std::round(input_ptr[j] * scale + bias)) +
419               output_zeropoint;
420           output_ptr[j] =
421               static_cast<uint8_t>(std::max(std::min(255, value), 0));
422         }
423       }
424       output_ptr += copy_size;
425     }
426   }
427 }
428 
429 template <typename Scalar>
430 void DepthConcatenation(const ConcatenationParams& params,
431                         const RuntimeShape* const* input_shapes,
432                         const Scalar* const* input_data,
433                         const RuntimeShape& output_shape, Scalar* output_data) {
434   ruy::profiler::ScopeLabel label("DepthConcatenation");
435   auto params_copy = params;
436   params_copy.axis = 3;
437   Concatenation(params_copy, input_shapes, input_data, output_shape,
438                 output_data);
439 }
440 
441 inline void LstmCell(
442     const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
443     const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
444     const float* prev_activ_data, const RuntimeShape& weights_shape,
445     const float* weights_data, const RuntimeShape& unextended_bias_shape,
446     const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
447     const float* prev_state_data,
448     const RuntimeShape& unextended_output_state_shape, float* output_state_data,
449     const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
450     const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
451     const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
452   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
453   TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
454   TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
455   TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
456   TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
457   TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
458   TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
459   TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
460   const RuntimeShape input_shape =
461       RuntimeShape::ExtendedShape(4, unextended_input_shape);
462   const RuntimeShape prev_activ_shape =
463       RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
464   const RuntimeShape bias_shape =
465       RuntimeShape::ExtendedShape(4, unextended_bias_shape);
466   const RuntimeShape prev_state_shape =
467       RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
468   const RuntimeShape output_state_shape =
469       RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
470   const RuntimeShape output_activ_shape =
471       RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
472   const RuntimeShape concat_temp_shape =
473       RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
474   const RuntimeShape activ_temp_shape =
475       RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
476   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
477 
478   const int weights_dim_count = weights_shape.DimensionsCount();
479   const int batches =
480       MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
481                   output_state_shape, 0, output_activ_shape, 0);
482   const int height =
483       MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
484                   output_state_shape, 1, output_activ_shape, 1);
485   const int width =
486       MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
487                   output_state_shape, 2, output_activ_shape, 2);
488   const int input_depth = input_shape.Dims(3);
489   const int prev_activ_depth = prev_activ_shape.Dims(3);
490   const int total_input_depth = prev_activ_depth + input_depth;
491   TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
492                    total_input_depth);
493   TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
494   const int intern_activ_depth =
495       MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
496   TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
497                    intern_activ_depth * total_input_depth);
498   TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
499   const int output_depth =
500       MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
501                   3, output_activ_shape, 3);
502   TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
503 
504   // Concatenate prev_activ and input data together
505   std::vector<float const*> concat_input_arrays_data;
506   std::vector<RuntimeShape const*> concat_input_arrays_shapes;
507   concat_input_arrays_data.push_back(input_data);
508   concat_input_arrays_data.push_back(prev_activ_data);
509   concat_input_arrays_shapes.push_back(&input_shape);
510   concat_input_arrays_shapes.push_back(&prev_activ_shape);
511   tflite::ConcatenationParams concat_params;
512   concat_params.axis = 3;
513   concat_params.inputs_count = concat_input_arrays_data.size();
514   Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
515                 &(concat_input_arrays_data[0]), concat_temp_shape,
516                 concat_temp_data);
517 
518   // Fully connected
519   tflite::FullyConnectedParams fc_params;
520   fc_params.float_activation_min = std::numeric_limits<float>::lowest();
521   fc_params.float_activation_max = std::numeric_limits<float>::max();
522   FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
523                  weights_data, bias_shape, bias_data, activ_temp_shape,
524                  activ_temp_data);
525 
526   // Memory state update (the LSTM "guts")
527   for (int b = 0; b < batches; ++b) {
528     for (int w = 0; w < width; ++w) {
529       for (int h = 0; h < height; ++h) {
530         for (int c = 0; c < output_depth; ++c) {
531           const float input_gate =
532               1.f /
533               (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
534                                                       0 * output_depth + c)]));
535           const float new_input = std::tanh(activ_temp_data[Offset(
536               activ_temp_shape, b, h, w, 1 * output_depth + c)]);
537           const float forget_gate =
538               1.f /
539               (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
540                                                       2 * output_depth + c)]));
541           const float output_gate =
542               1.f /
543               (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
544                                                       3 * output_depth + c)]));
545           const float new_state =
546               input_gate * new_input +
547               forget_gate *
548                   prev_state_data[Offset(prev_state_shape, b, h, w, c)];
549           output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
550           output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
551               output_gate * std::tanh(new_state);
552         }
553       }
554     }
555   }
556 }
557 
558 // Quantized LSTM cell implementation.
559 // The quantization of the input, output arrays is as follows:
560 //  - The input activations are quantized as uint8 on the interval
561 //    [-1, 127/128].
562 //    The rationale for that is that this is the natural interval for output
563 //    activations (see next point) and these need to be concatenated together.
564 //    We could accommodate different ranges by re-scaling, but we empirically
565 //    found that setting the input activations range to be [-1, 127/128] in the
566 //    first place, removing the need for re-scaling, greatly improves accuracy.
567 //  - The output activations are quantized as uint8 on the interval
568 //    [-1, 127/128].
569 //    The rationale for that is that the definition of a LSTM cell makes them
570 //    intrinsically constrained in [-1, 1]; tweaking that to [-1, 127/128]
571 //    makes for simpler, more accurate fixed-point arithmetic.
572 //  - The output-at-previous-timestep state array is quantized in the same
573 //    way as the output activations.
574 //  - The internal LSTM memory (not the output-at-previous-timestep, the other
575 //    internal state array) is int16-quantized and may use any power-of-two,
576 //    symmetric range i.e. [-2^N, 2^N * 32767/32768] for any N, which we call
577 //    StateIntegerBits below, see the below discussion of that template
578 //    parameter ("The StateIntegerBits template parameter").
579 //  - The output of the internal fully-connected node is int16-quantized
580 //    on the interval [-8, 8 * 32767/32768], the rationale for which is
581 //    explained just below ("Why [-8, 8] for fully-connected output?").
582 //
583 //
584 // === The StateIntegerBits template parameter ===
585 //
586 // The StateIntegerBits template parameter controls the fixed-point format used
587 // to represent the internal memory of the LSTM cell (not the
588 // output-at-previous-timestep, the other internal state array). It's currently
589 // a template parameter so that the model can control that. The most typical
590 // value for StateIntegerBits is 4. Other plausible values are anywhere between
591 // 3 and 5. We might eventually standardize on a single supported value, e.g. 4,
592 // and drop that template parameter. The reason why it can't be a runtime
593 // parameter is that this controls the fixed-point format used, i.e. we need to
594 // generate actually different code based on it. In particular, we generate code
595 // for a fixed-point tanh() implementation for that format, which internally
596 // uses a fixed-point exp() implementation, which internally uses a
597 // barrel-shifter with a number of steps that depends on StateIntegerBits.
598 // Another consequence of that is that a higher value of StateIntegerBits
599 // results in a more expensive implementation (more barrel shifter steps
600 // needed).
601 //
602 //
603 // === Why [-8, 8] for fully-connected output? ===
604 //
605 // This array is only fed to Logistic and Tanh functions, for which
606 // the quantized implementation will want to use fixed-point arithmetic,
607 // requiring a power-of-two representation interval. Thus, we should right
608 // away quantize this array to a power-of-two interval; otherwise,
609 // implementation will need to rescale that, losing any benefit that a tighter
610 // representation interval might otherwise yield, while introducing some
611 // numerical error and computational overhead.
612 //
613 // Now, Logistic and Tanh
614 // are nearly constant (nearly equal to their horizontal asymptotes)
615 // outside of a small bounded interval around 0:
616 //
617 //   Logistic(4) = 1 - 1.8e-2     Tanh(4) = 1 - 6.7e-4
618 //   Logistic(8) = 1 - 3.4e-4     Tanh(8) = 1 - 2.3e-7
619 //   Logistic(16) = 1 - 1.1e-7    Tanh(16) = 1 - 2.5e-14
620 //
621 // From this, we see that clamping to [-4, 4] would be too inaccurate
622 // (the error of 1.8e-2 on Logistic would be felt even in 8bit precision)
623 // while clamping to [-16, 16] would make no difference even in float32.
624 // However, for a fixed-point implementation in 16-bit integers, using 5
625 // integer bits to represent the [-16, 16] range would leave only 11
626 // fractional bits, giving an increment of 2^-11 = 4.9e-4 between consecutive
627 // representable values. Notice that this is higher than the
628 // worst-case clamping error with clamping to [-8, 8]: 3.4e-4 for Logistic.
629 // Using [-8, 8] thus seems like the better compromise overall, enjoying
630 // an increment of 2.4e-4 between representable values and a worst-case
631 // clamping error of 3.4e-4, both better than the increment of 4.9e-4 with
632 // [-16, 16].
633 //
634 // Moreover, all other things being equal, it is nice to choose the narrower
635 // representation range, as that makes the implementation of fixed-point
636 // math functions a little cheaper (each integer bit requires an additional
637 // barrel-shifter step in the implementation of exp(-x)). That is further
638 // reason to prefer [-8, 8] over [-16, 16]. The choice of [-16, 16] would make
639 // sense for 32-bit float or 32-bit fixed-point quantization, but we are
640 // aiming for 16-bit fixed-point quantization of these internal nodes here.
641 //
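// Worked example of the increment arithmetic above (illustrative only; the
// gemmlowp::FixedPoint aliases below are the same kind used in this kernel):
//
//   using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;  // range [-8, 8]
//   using F4 = gemmlowp::FixedPoint<std::int16_t, 4>;  // range [-16, 16]
//
//   F3 keeps 15 - 3 = 12 fractional bits -> increment 2^-12 ~= 2.4e-4
//   F4 keeps 15 - 4 = 11 fractional bits -> increment 2^-11 ~= 4.9e-4
//
// Since the worst-case clamping error of [-8, 8] (3.4e-4 on Logistic(8)) is
// smaller than the 4.9e-4 quantization step of [-16, 16], the narrower F3
// representation is the better compromise, as argued above.
//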
642 template <int StateIntegerBits>
643 inline void LstmCell(const LstmCellParams& params,
644                      const RuntimeShape& unextended_input_shape,
645                      const uint8* input_data_uint8,
646                      const RuntimeShape& unextended_prev_activ_shape,
647                      const uint8* prev_activ_data_uint8,
648                      const RuntimeShape& weights_shape,
649                      const uint8* weights_data_uint8,
650                      const RuntimeShape& unextended_bias_shape,
651                      const int32* bias_data_int32,
652                      const RuntimeShape& unextended_prev_state_shape,
653                      const int16* prev_state_data_int16,
654                      const RuntimeShape& unextended_output_state_shape,
655                      int16* output_state_data_int16,
656                      const RuntimeShape& unextended_output_activ_shape,
657                      uint8* output_activ_data_uint8,
658                      const RuntimeShape& unextended_concat_temp_shape,
659                      uint8* concat_temp_data_uint8,
660                      const RuntimeShape& unextended_activ_temp_shape,
661                      int16* activ_temp_data_int16, void* gemmlowp_context) {
662   (void)gemmlowp_context;  // only used in optimized code.
663   int32 weights_zero_point = params.weights_zero_point;
664   int32 accum_multiplier = params.accum_multiplier;
665   int accum_shift = params.accum_shift;
666   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
667   TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
668   TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
669   TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
670   TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
671   TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
672   TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
673   TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
674   const RuntimeShape input_shape =
675       RuntimeShape::ExtendedShape(4, unextended_input_shape);
676   const RuntimeShape prev_activ_shape =
677       RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
678   const RuntimeShape bias_shape =
679       RuntimeShape::ExtendedShape(4, unextended_bias_shape);
680   const RuntimeShape prev_state_shape =
681       RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
682   const RuntimeShape output_state_shape =
683       RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
684   const RuntimeShape output_activ_shape =
685       RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
686   const RuntimeShape concat_temp_shape =
687       RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
688   const RuntimeShape activ_temp_shape =
689       RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
690   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
691 
692   // Gather dimensions information, and perform consistency checks.
693   const int weights_dim_count = weights_shape.DimensionsCount();
694   const int outer_size = MatchingFlatSizeSkipDim(
695       input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
696       output_activ_shape);
697   const int input_depth = input_shape.Dims(3);
698   const int prev_activ_depth = prev_activ_shape.Dims(3);
699   const int total_input_depth = prev_activ_depth + input_depth;
700   TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
701                    total_input_depth);
702   const int intern_activ_depth =
703       MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
704   TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
705                    intern_activ_depth * total_input_depth);
706   TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
707   TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
708   const int output_depth =
709       MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
710                   3, output_activ_shape, 3);
711   TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
712   const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
713   const int fc_output_depth =
714       MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
715   const int fc_accum_depth = total_input_depth;
716   TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
717 
718   // Depth-concatenate prev_activ and input data together.
719   uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
720                                               prev_activ_data_uint8};
721   const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
722                                                        &prev_activ_shape};
723   tflite::ConcatenationParams concat_params;
724   concat_params.axis = 3;
725   concat_params.inputs_count = 2;
726   Concatenation(concat_params, concat_input_arrays_shapes,
727                 concat_input_arrays_data, concat_temp_shape,
728                 concat_temp_data_uint8);
729 
730   // Implementation of the fully connected node inside the LSTM cell.
731   // The operands are 8-bit integers, the accumulators are internally 32bit
732   // integers, and the output is 16-bit fixed-point with 3 integer bits so
733   // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
734   // is explained in the function comment above.
735   for (int b = 0; b < fc_batches; ++b) {
736     for (int out_c = 0; out_c < fc_output_depth; ++out_c) {
737       // Internal accumulation.
738       // Initialize accumulator with the bias-value.
739       int32 accum = bias_data_int32[out_c];
740       // Accumulation loop.
741       for (int d = 0; d < fc_accum_depth; ++d) {
742         int16 input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128;
743         int16 weights_val =
744             weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point;
745         accum += input_val * weights_val;
746       }
747       // Down-scale the final int32 accumulator to the scale used by our
748       // (16-bit, using 3 integer bits) fixed-point format. The quantized
749       // multiplier and shift here have been pre-computed offline
750       // (e.g. by toco).
751       accum =
752           MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift);
753       // Saturate, cast to int16, and store to the temporary activations array.
754       accum = std::max(-32768, std::min(32767, accum));
755       activ_temp_data_int16[out_c + fc_output_depth * b] = accum;
756     }
757   }
758 
759   // Rest of the LSTM cell: tanh and logistic math functions, and some adds
760   // and muls, all done in 16-bit fixed-point.
761   for (int b = 0; b < outer_size; ++b) {
762     for (int c = 0; c < output_depth; ++c) {
763       // Define the fixed-point data types that we will use here. All use
764       // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
765       // They only differ by the number of integral vs. fractional bits,
766       // determining the range of values that they can represent.
767       //
768       // F0 uses 0 integer bits, range [-1, 1].
769       // This is the return type of math functions such as tanh, logistic,
770       // whose range is in [-1, 1].
771       using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
772       // F3 uses 3 integer bits, range [-8, 8].
773       // This is the range of the previous fully-connected node's output,
774       // which is our input here.
775       using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
776       // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
777       // 2^StateIntegerBits]. It's used to represent the internal state, whose
778       // number of integer bits is currently dictated by the model. See comment
779       // on the StateIntegerBits template parameter above.
780       using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
781       // Implementation of input gate, using fixed-point logistic function.
782       F3 input_gate_input = F3::FromRaw(
783           activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]);
784       F0 input_gate_output = gemmlowp::logistic(input_gate_input);
785       // Implementation of input modulation gate, using fixed-point tanh
786       // function.
787       F3 input_modulation_gate_input = F3::FromRaw(
788           activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]);
789       F0 input_modulation_gate_output =
790           gemmlowp::tanh(input_modulation_gate_input);
791       // Implementation of forget gate, using fixed-point logistic function.
792       F3 forget_gate_input = F3::FromRaw(
793           activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]);
794       F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
795       // Implementation of output gate, using fixed-point logistic function.
796       F3 output_gate_input = F3::FromRaw(
797           activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]);
798       F0 output_gate_output = gemmlowp::logistic(output_gate_input);
799       // Implementation of internal multiplication nodes, still in fixed-point.
800       F0 input_times_input_modulation =
801           input_gate_output * input_modulation_gate_output;
802       FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]);
803       FS prev_state_times_forget_state = forget_gate_output * prev_state;
804       // Implementation of internal addition node, saturating.
805       FS new_state = gemmlowp::SaturatingAdd(
806           gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
807           prev_state_times_forget_state);
808       // Implementation of last internal Tanh node, still in fixed-point.
809       // Since a Tanh fixed-point implementation is specialized for a given
810 // number of integer bits, and each specialization can have a substantial
811       // code size, and we already used above a Tanh on an input with 3 integer
812       // bits, and per the table in the above function comment there is no
813       // significant accuracy to be lost by clamping to [-8, +8] for a
814       // 3-integer-bits representation, let us just do that. This helps people
815       // porting this to targets where code footprint must be minimized.
816       F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
817       F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
818       // Store the new internal state back to memory, as 16-bit integers.
819       // Note: here we store the original value with StateIntegerBits, not
820       // the rescaled 3-integer-bits value fed to tanh.
821       output_state_data_int16[b * output_depth + c] = new_state.raw();
822       // Down-scale the output activations to 8-bit integers, saturating,
823       // and store back to memory.
824       int16 rescaled_output_activ =
825           gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
826       int16 clamped_output_activ =
827           std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
828       output_activ_data_uint8[b * output_depth + c] =
829           128 + clamped_output_activ;
830     }
831   }
832 }
833 
834 template <typename Scalar>
835 void Split(const SplitParams& params, const RuntimeShape& input_shape,
836            const Scalar* input_data, const RuntimeShape* const* output_shapes,
837            Scalar* const* output_data) {
838   ruy::profiler::ScopeLabel label("Split");
839   const int split_dimensions = input_shape.DimensionsCount();
840   int axis = params.axis < 0 ? params.axis + split_dimensions : params.axis;
841   int outputs_count = params.num_split;
842   TFLITE_DCHECK_LT(axis, split_dimensions);
843 
844   int64_t split_size = 0;
845   for (int i = 0; i < outputs_count; i++) {
846     TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), split_dimensions);
847     for (int j = 0; j < split_dimensions; j++) {
848       if (j != axis) {
849         MatchingDim(*output_shapes[i], j, input_shape, j);
850       }
851     }
852     split_size += output_shapes[i]->Dims(axis);
853   }
854   TFLITE_DCHECK_EQ(split_size, input_shape.Dims(axis));
855   int64_t outer_size = 1;
856   for (int i = 0; i < axis; ++i) {
857     outer_size *= input_shape.Dims(i);
858   }
859   // For all output arrays,
860   // FlatSize() = outer_size * Dims(axis) * base_inner_size;
861   int64_t base_inner_size = 1;
862   for (int i = axis + 1; i < split_dimensions; ++i) {
863     base_inner_size *= input_shape.Dims(i);
864   }
865 
866   const Scalar* input_ptr = input_data;
867   for (int k = 0; k < outer_size; k++) {
868     for (int i = 0; i < outputs_count; ++i) {
869       const int copy_size = output_shapes[i]->Dims(axis) * base_inner_size;
870       memcpy(output_data[i] + k * copy_size, input_ptr,
871              copy_size * sizeof(Scalar));
872       input_ptr += copy_size;
873     }
874   }
875 }
876 
877 inline int NodeOffset(int b, int h, int w, int height, int width) {
878   return (b * height + h) * width + w;
879 }
880 
881 inline void LocalResponseNormalization(
882     const tflite::LocalResponseNormalizationParams& op_params,
883     const RuntimeShape& input_shape, const float* input_data,
884     const RuntimeShape& output_shape, float* output_data) {
885   const int trailing_dim = input_shape.DimensionsCount() - 1;
886   const int outer_size =
887       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
888   const int depth =
889       MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
890 
891   for (int i = 0; i < outer_size; ++i) {
892     for (int c = 0; c < depth; ++c) {
893       const int begin_input_c = std::max(0, c - op_params.range);
894       const int end_input_c = std::min(depth, c + op_params.range);
895       float accum = 0.f;
896       for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) {
897         const float input_val = input_data[i * depth + input_c];
898         accum += input_val * input_val;
899       }
900       const float multiplier =
901           std::pow(op_params.bias + op_params.alpha * accum, -op_params.beta);
902       output_data[i * depth + c] = input_data[i * depth + c] * multiplier;
903     }
904   }
905 }
906 
907 inline void Dequantize(const RuntimeShape& input_shape,
908                        const Eigen::half* input_data,
909                        const RuntimeShape& output_shape, float* output_data) {
910   const int flat_size = MatchingFlatSize(input_shape, output_shape);
911   for (int i = 0; i < flat_size; i++) {
912     output_data[i] = static_cast<float>(input_data[i]);
913   }
914 }
915 
916 inline void FakeQuant(const tflite::FakeQuantParams& op_params,
917                       const RuntimeShape& input_shape, const float* input_data,
918                       const RuntimeShape& output_shape, float* output_data) {
919   ruy::profiler::ScopeLabel label("FakeQuant");
920   float rmin = op_params.minmax.min;
921   float rmax = op_params.minmax.max;
922   int num_bits = op_params.num_bits;
923   // 0 should always be a representable value. Let's assume that the initial
924   // min,max range contains 0.
925   TFLITE_DCHECK_LE(rmin, 0.0f);
926   TFLITE_DCHECK_GE(rmax, 0.0f);
927   TFLITE_DCHECK_LT(rmin, rmax);
928 
929   // Code matches tensorflow's FakeQuantWithMinMaxArgsFunctor.
930   int quant_min = 0;
931   int quant_max = (1 << num_bits) - 1;
932   float nudged_min, nudged_max, nudged_scale;
933   NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min,
934                          &nudged_max, &nudged_scale);
935   const int flat_size = MatchingFlatSize(input_shape, output_shape);
936   FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data,
937                     output_data, flat_size);
938 }
939 
940 // Common subroutine for both `GatherNd` and `GatherNdString`.
941 struct GatherNdHelperResult {
942   int n_slices;
943   int slice_size;
944   int indices_nd;
945   std::vector<int> dims_to_count;
946 };
947 
948 // Returns common values being used on both `GatherNd` and `GatherNdString`.
949 inline GatherNdHelperResult GatherNdHelper(const RuntimeShape& params_shape,
950                                            const RuntimeShape& indices_shape) {
951   GatherNdHelperResult ret;
952   ret.n_slices = 1;
953   ret.slice_size = 1;
954   const int indices_dims = indices_shape.DimensionsCount();
955   ret.indices_nd = indices_shape.Dims(indices_dims - 1);
956   const int params_dims = params_shape.DimensionsCount();
957   for (int i = 0; i < indices_dims - 1; ++i) {
958     ret.n_slices *= indices_shape.Dims(i);
959   }
960   for (int i = ret.indices_nd; i < params_dims; ++i) {
961     ret.slice_size *= params_shape.Dims(i);
962   }
963 
964   int remain_flat_size = params_shape.FlatSize();
965   ret.dims_to_count = std::vector<int>(ret.indices_nd, 0);
966   for (int i = 0; i < ret.indices_nd; ++i) {
967     ret.dims_to_count[i] = remain_flat_size / params_shape.Dims(i);
968     remain_flat_size = ret.dims_to_count[i];
969   }
970 
971   return ret;
972 }
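// Example (hypothetical shapes, for illustration only): with
// params_shape = [4, 5, 6] and indices_shape = [2, 3, 2], the trailing
// indices dimension gives indices_nd = 2, so each index row addresses a
// slice of shape [6]:
//   n_slices      = 2 * 3 = 6
//   slice_size    = 6
//   dims_to_count = {5 * 6, 6} = {30, 6}  // element strides of the first
//                                         // indices_nd dims of params_shape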
973 
974 template <typename ParamsT, typename IndicesT = int32>
975 inline void GatherNd(const RuntimeShape& params_shape,
976                      const ParamsT* params_data,
977                      const RuntimeShape& indices_shape,
978                      const IndicesT* indices_data,
979                      const RuntimeShape& output_shape, ParamsT* output_data) {
980   ruy::profiler::ScopeLabel label("GatherNd");
981 
982   const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
983   for (int i = 0; i < res.n_slices; ++i) {
984     int from_pos = 0;
985     for (int j = 0; j < res.indices_nd; ++j) {
986       from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
987     }
988     std::memcpy(output_data + i * res.slice_size, params_data + from_pos,
989                 sizeof(ParamsT) * res.slice_size);
990   }
991 }
992 
993 #ifndef TF_LITE_STATIC_MEMORY
994 template <typename IndicesT = int32>
995 inline void GatherNdString(const RuntimeShape& params_shape,
996                            const TfLiteTensor* params_data,
997                            const RuntimeShape& indices_shape,
998                            const IndicesT* indices_data,
999                            const RuntimeShape& output_shape,
1000                            TfLiteTensor* output_data) {
1001   ruy::profiler::ScopeLabel label("GatherNdString");
1002 
1003   const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
1004   DynamicBuffer buffer;
1005   for (int i = 0; i < res.n_slices; ++i) {
1006     int from_pos = 0;
1007     for (int j = 0; j < res.indices_nd; ++j) {
1008       from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
1009     }
1010     for (int j = 0; j < res.slice_size; ++j) {
1011       buffer.AddString(GetString(params_data, from_pos + j));
1012     }
1013   }
1014   buffer.WriteToTensor(output_data, /*new_shape=*/nullptr);
1015 }
1016 #endif
1017 
1018 template <typename IndicesT, typename UpdatesT>
1019 inline void ScatterNd(const RuntimeShape& indices_shape,
1020                       const IndicesT* indices_data,
1021                       const RuntimeShape& updates_shape,
1022                       const UpdatesT* updates_data,
1023                       const RuntimeShape& output_shape, UpdatesT* output_data) {
1024   ruy::profiler::ScopeLabel label("ScatterNd");
1025 
1026   int n_slices = 1;
1027   int slice_size = 1;
1028   const int outer_dims = indices_shape.DimensionsCount() - 1;
1029   const int indices_nd = indices_shape.Dims(outer_dims);
1030   const int updates_dims = updates_shape.DimensionsCount();
1031   for (int i = 0; i < outer_dims; ++i) {
1032     n_slices *= indices_shape.Dims(i);
1033   }
1034   for (int i = outer_dims; i < updates_dims; ++i) {
1035     slice_size *= updates_shape.Dims(i);
1036   }
1037 
1038   int output_flat_size = output_shape.FlatSize();
1039   int remain_flat_size = output_flat_size;
1040   std::vector<int> dims_to_count(indices_nd, 0);
1041   for (int i = 0; i < indices_nd; ++i) {
1042     dims_to_count[i] = remain_flat_size / output_shape.Dims(i);
1043     remain_flat_size = dims_to_count[i];
1044   }
1045 
1046   memset(output_data, 0, sizeof(UpdatesT) * output_flat_size);
1047   for (int i = 0; i < n_slices; ++i) {
1048     int to_pos = 0;
1049     for (int j = 0; j < indices_nd; ++j) {
1050       IndicesT idx = indices_data[i * indices_nd + j];
1051       TFLITE_DCHECK(0 <= idx && idx < output_shape.Dims(j));
1052       to_pos += idx * dims_to_count[j];
1053     }
1054     for (int j = 0; j < slice_size; j++) {
1055       output_data[to_pos + j] += updates_data[i * slice_size + j];
1056     }
1057   }
1058 }
1059 
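// Copies the sub-tensor selected by op_params.begin / op_params.size into the
// writer, one element at a time. A size of -1 means "to the end of that
// dimension"; begin/size vectors shorter than 5 entries are implicitly
// front-padded.
//
// For example: input of shape [3, 2] = {1, 2, 3, 4, 5, 6} with begin = {1, 0}
// and size = {2, 2} yields {3, 4, 5, 6}.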
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
                  const RuntimeShape& input_shape,
                  const RuntimeShape& output_shape,
                  SequentialTensorWriter<T>* writer) {
  const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(5, input_shape);
  TFLITE_DCHECK_LE(op_params.begin_count, 5);
  TFLITE_DCHECK_LE(op_params.size_count, 5);
  const int begin_count = op_params.begin_count;
  const int size_count = op_params.size_count;
  // We front-pad the begin and size vectors.
  std::array<int, 5> start;
  std::array<int, 5> stop;
  for (int i = 0; i < 5; ++i) {
    int padded_i = 5 - i;
    start[i] =
        begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
    stop[i] =
        (size_count < padded_i || op_params.size[size_count - padded_i] == -1)
            ? ext_shape.Dims(i)
            : start[i] + op_params.size[size_count - padded_i];
  }

  for (int i0 = start[0]; i0 < stop[0]; ++i0) {
    for (int i1 = start[1]; i1 < stop[1]; ++i1) {
      for (int i2 = start[2]; i2 < stop[2]; ++i2) {
        for (int i3 = start[3]; i3 < stop[3]; ++i3) {
          for (int i4 = start[4]; i4 < stop[4]; ++i4) {
            writer->Write(Offset(ext_shape, i0, i1, i2, i3, i4));
          }
        }
      }
    }
  }
}

template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
                  const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
  SequentialTensorWriter<T> writer(input_data, output_data);
  return Slice(op_params, input_shape, output_shape, &writer);
}

template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
                  const RuntimeShape& input_shape, const TfLiteTensor* input,
                  const RuntimeShape& output_shape, TfLiteTensor* output) {
  SequentialTensorWriter<T> writer(input, output);
  return Slice(op_params, input_shape, output_shape, &writer);
}

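// Elementwise minimum of `input1_data` against a scalar threshold taken from
// input2_data[0] (the second input is treated as a scalar).
//
// For example: input1 = {1, 5, 3} and a scalar input2 of 4 produce {1, 4, 3}.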
template <typename T>
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
             const T* input2_data, const RuntimeShape& output_shape,
             T* output_data) {
  const int flat_size = MatchingFlatSize(input1_shape, output_shape);

  auto min_value = input2_data[0];
  for (int i = 0; i < flat_size; i++) {
    output_data[i] = input1_data[i] > min_value ? min_value : input1_data[i];
  }
}

// Convenience version that allows, for example, generated-code calls to be
// the same as other binary ops.
template <typename T>
inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
                    const RuntimeShape&, const T* input2_data,
                    const RuntimeShape& output_shape, T* output_data) {
  // Drop shape of second input: not needed.
  Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
}

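// Elementwise maximum of `input1_data` against a scalar threshold taken from
// input2_data[0] (the second input is treated as a scalar).
//
// For example: input1 = {1, 5, 3} and a scalar input2 of 4 produce {4, 5, 4}.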
template <typename T>
void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
             const T* input2_data, const RuntimeShape& output_shape,
             T* output_data) {
  const int flat_size = MatchingFlatSize(input1_shape, output_shape);

  auto max_value = input2_data[0];
  for (int i = 0; i < flat_size; i++) {
    output_data[i] = input1_data[i] < max_value ? max_value : input1_data[i];
  }
}

// Convenience version that allows, for example, generated-code calls to be
// the same as other binary ops.
template <typename T>
inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
                    const RuntimeShape&, const T* input2_data,
                    const RuntimeShape& output_shape, T* output_data) {
  // Drop shape of second input: not needed.
  Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
}

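// ArgMax is implemented in terms of ArgMinMax with a std::greater comparator:
// it writes the index of the largest element along the reduced axis.
// For example: the argmax of {1, 9, 2} is 1.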
template <typename T1, typename T2, typename T3>
void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
            const T3* input2_data, const RuntimeShape& output_shape,
            T2* output_data) {
  ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
            std::greater<T1>());
}

// Convenience version that allows, for example, generated-code calls to be
// the same as other binary ops.
template <typename T1, typename T2, typename T3>
inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
                   const RuntimeShape& input2_shape, const T3* input2_data,
                   const RuntimeShape& output_shape, T2* output_data) {
  // Drop shape of second input: not needed.
  ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
}

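// Elementwise select: output[i] = condition[i] ? x[i] : y[i] over the flat
// buffers, which must all have the same number of elements (scalar and
// single-element inputs are also accepted).
//
// For example: condition = {true, false, true}, x = {1, 2, 3}, y = {7, 8, 9}
// produce {1, 8, 3}.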
template <typename D, typename T>
void Select(const RuntimeShape& input_condition_shape,
            const D* input_condition_data, const RuntimeShape& input_x_shape,
            const T* input_x_data, const RuntimeShape& input_y_shape,
            const T* input_y_data, const RuntimeShape& output_shape,
            T* output_data) {
  int64_t flatsize;
  // Allow the select operator to be executed on a mix of scalar and
  // single-element tensors.
  if (input_condition_shape.FlatSize() == 1 && input_x_shape.FlatSize() == 1 &&
      input_y_shape.FlatSize() == 1 && output_shape.FlatSize() == 1) {
    flatsize = 1;
  } else {
    flatsize = MatchingFlatSize(input_condition_shape, input_x_shape,
                                input_y_shape, output_shape);
  }
  for (int64_t i = 0; i < flatsize; ++i) {
    output_data[i] =
        input_condition_data[i] ? input_x_data[i] : input_y_data[i];
  }
}

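// Select with a rank-1 (or scalar) condition: each condition element picks an
// entire inner slice of either x or y, which is copied to the output.
//
// For example: condition = {true, false} with x and y of shape [2, 2] copies
// row 0 from x and row 1 from y.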
template <typename D, typename T>
void RankOneSelect(const RuntimeShape& input_condition_shape,
                   const D* input_condition_data,
                   const RuntimeShape& input_x_shape, const T* input_x_data,
                   const RuntimeShape& input_y_shape, const T* input_y_data,
                   const RuntimeShape& output_shape, T* output_data) {
  const int64_t outer_size = input_condition_shape.FlatSize();
  int64_t inner_size;
  if (input_condition_shape.DimensionsCount() == 0) {
    inner_size = MatchingFlatSize(input_x_shape, input_y_shape, output_shape);
  } else {
    TFLITE_DCHECK_EQ(
        MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0),
        outer_size);
    inner_size =
        MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
  }

  int64_t offset = 0;
  for (int64_t i = 0; i < outer_size; i++) {
    const T* input_data = input_condition_data[i] ? input_x_data : input_y_data;
    memcpy(output_data + offset, input_data + offset, inner_size * sizeof(T));
    offset += inner_size;
  }
}

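// Select for condition/x/y tensors of up to 4 dimensions whose shapes are
// broadcast against each other; indices are computed per output element, so
// this path is slower than the flat versions above.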
template <typename D, typename T>
void BroadcastSelect4DSlow(const RuntimeShape& input_condition_shape,
                           const D* input_condition_data,
                           const RuntimeShape& input_x_shape,
                           const T* input_x_data,
                           const RuntimeShape& input_y_shape,
                           const T* input_y_data,
                           const RuntimeShape& output_shape, T* output_data) {
  TFLITE_DCHECK_LE(input_condition_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(input_x_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(input_y_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 4);

  const RuntimeShape extended_output_shape =
      RuntimeShape::ExtendedShape(4, output_shape);

  NdArrayDesc<4> desc_condition;
  NdArrayDesc<4> desc_x;
  NdArrayDesc<4> desc_y;
  NdArrayDescsForElementwiseBroadcast(input_condition_shape, input_x_shape,
                                      input_y_shape, &desc_condition, &desc_x,
                                      &desc_y);

  // In Tensorflow, the dimensions are canonically named (batch_number, row,
  // col, channel), with extents (batches, height, width, depth), with the
  // trailing dimension changing most rapidly (channels has the smallest
  // stride, typically 1 element).
  //
  // In generated C code, we store arrays with the dimensions reversed. The
  // first dimension has smallest stride.
  //
  // We name our variables by their Tensorflow convention, but generate C code
  // nesting loops such that the innermost loop has the smallest stride for
  // the best cache behavior.
  for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
    for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
      for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
        for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
          const int condition_index =
              SubscriptToIndex(desc_condition, b, y, x, c);
          const int x_index = SubscriptToIndex(desc_x, b, y, x, c);
          const int y_index = SubscriptToIndex(desc_y, b, y, x, c);
          output_data[Offset(extended_output_shape, b, y, x, c)] =
              input_condition_data[condition_index] ? input_x_data[x_index]
                                                    : input_y_data[y_index];
        }
      }
    }
  }
}

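// Writes the row-major coordinates of every true element of the condition
// tensor; each true element contributes `cond_rank` consecutive values to
// `output_data`.
//
// For example: a [2, 2] condition of {true, false, false, true} produces
// {0, 0, 1, 1}.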
template <typename D, typename T>
void SelectTrueCoords(const RuntimeShape& input_condition_shape,
                      const D* input_condition_data, T* output_data) {
  const size_t size = input_condition_shape.FlatSize();
  if (size == 0) {
    // The flat size is zero, so there is nothing to output.
    return;
  }
  const size_t cond_rank = input_condition_shape.DimensionsCount();

  std::vector<int> dims_to_count(cond_rank, 0);
  int cur_flat_size = size;
  for (int i = 0; i < cond_rank; ++i) {
    dims_to_count[i] = cur_flat_size / input_condition_shape.Dims(i);
    cur_flat_size = dims_to_count[i];
  }

  int output_index = 0;
  for (int i = 0; i < size; ++i) {
    if (input_condition_data[i]) {
      // Insert the coordinate of the current item (row major) into output.
      int flat_index = i;
      for (int j = 0; j < cond_rank; ++j) {
        int coord_j = flat_index / dims_to_count[j];
        output_data[output_index * cond_rank + j] = coord_j;
        flat_index %= dims_to_count[j];
      }
      output_index++;
    }
  }
}

// For simplicity of implementation, the indices are always given as a vector
// of size-4 vectors.
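// The output is first filled with `default_value` and then overwritten at the
// given indices.
//
// For example: with an output of shape [1, 1, 2, 2], default_value 0,
// indices {{0, 0, 0, 1}, {0, 0, 1, 0}} and values {5, 7}, the result is
// {0, 5, 7, 0}.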
template <typename T, typename TI>
inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
                          const T* values, T default_value,
                          bool value_is_scalar,
                          const RuntimeShape& unextended_output_shape,
                          T* output_data) {
  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
  const RuntimeShape output_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_shape);
  const int value_count = indices.size();

  // First, fill output_data with the default value.
  const int num_elements = output_shape.FlatSize();
  for (int i = 0; i < num_elements; ++i) {
    output_data[i] = default_value;
  }

  // Special-case a scalar value to avoid checking the boolean condition
  // inside the loop on every iteration.
  if (value_is_scalar) {
    for (int i = 0; i < value_count; ++i) {
      const std::vector<TI>& index = indices[i];
      TFLITE_DCHECK_EQ(index.size(), 4);
      const T value = *values;  // just use the first value.
      output_data[Offset(output_shape, index[0], index[1], index[2],
                         index[3])] = value;
    }
    return;
  }

  // Go through the values and indices to fill in the sparse values.
  for (int i = 0; i < value_count; ++i) {
    const std::vector<TI>& index = indices[i];
    TFLITE_DCHECK_EQ(index.size(), 4);
    const T value = values[i];
    output_data[Offset(output_shape, index[0], index[1], index[2], index[3])] =
        value;
  }
}

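// Elementwise power: output[i] = input1[i] raised to input2[i].
// For example: input1 = {2, 3} and input2 = {3, 2} produce {8, 9}.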
template <typename T>
inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
                const RuntimeShape& input2_shape, const T* input2_data,
                const RuntimeShape& output_shape, T* output_data) {
  const int flat_size =
      MatchingFlatSize(input1_shape, input2_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    output_data[i] = std::pow(input1_data[i], input2_data[i]);
  }
}

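// Broadcasting variant of Pow for inputs of up to 4 dimensions; the two input
// shapes are broadcast against each other per output element.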
template <typename T>
inline void BroadcastPow4DSlow(const RuntimeShape& unextended_input1_shape,
                               const T* input1_data,
                               const RuntimeShape& unextended_input2_shape,
                               const T* input2_data,
                               const RuntimeShape& unextended_output_shape,
                               T* output_data) {
  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
  const RuntimeShape output_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_shape);

  NdArrayDesc<4> desc1;
  NdArrayDesc<4> desc2;
  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
                                      unextended_input2_shape, &desc1, &desc2);

  for (int b = 0; b < output_shape.Dims(0); ++b) {
    for (int y = 0; y < output_shape.Dims(1); ++y) {
      for (int x = 0; x < output_shape.Dims(2); ++x) {
        for (int c = 0; c < output_shape.Dims(3); ++c) {
          auto out_idx = Offset(output_shape, b, y, x, c);
          auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
          auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
          auto in1_val = input1_data[in1_idx];
          auto in2_val = input2_data[in2_idx];
          output_data[out_idx] = std::pow(in1_val, in2_val);
        }
      }
    }
  }
}

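// Reverses the input along the given axis, copying contiguous blocks so that
// all other dimensions are preserved.
//
// For example: input of shape [2, 3] = {1, 2, 3, 4, 5, 6} with axis = 1
// produces {3, 2, 1, 6, 5, 4}.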
template <typename Scalar>
void Reverse(int axis, const RuntimeShape& input_shape,
             const Scalar* input_data, const RuntimeShape& output_shape,
             Scalar* output_data) {
  ruy::profiler::ScopeLabel label("Reverse");

  int outer_size = 1;
  for (int i = 0; i < axis; ++i) {
    outer_size *= input_shape.Dims(i);
  }

  int copy_size = 1;
  for (int i = axis + 1; i < input_shape.DimensionsCount(); ++i) {
    copy_size *= input_shape.Dims(i);
  }

  const int dims_at_axis = input_shape.Dims(axis);
  for (int i = 0; i < outer_size; ++i) {
    for (int j = 0; j < dims_at_axis; ++j) {
      const int start_pos = (i * dims_at_axis + j) * copy_size;
      Scalar* output_ptr = output_data + start_pos;
      int loc = (i * dims_at_axis + dims_at_axis - j - 1) * copy_size;
      memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
    }
  }
}

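// Reverses the leading `seq_lengths[b]` elements along `seq_dim` for each
// index b along `batch_dim`; elements beyond the sequence length are copied
// through unchanged.
//
// For example: input of shape [2, 3] = {1, 2, 3, 4, 5, 6} with batch_dim = 0,
// seq_dim = 1 and seq_lengths = {2, 3} produces {2, 1, 3, 6, 5, 4}.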
template <typename Scalar, typename TS>
void ReverseSequence(const TS* seq_lengths, const int seq_dim,
                     const int batch_dim, const RuntimeShape& input_shape,
                     const Scalar* input_data, const RuntimeShape& output_shape,
                     Scalar* output_data) {
  ruy::profiler::ScopeLabel label("ReverseSequence");

  int outer_size = 1;
  int outer_dim = std::min(batch_dim, seq_dim);
  int medium_dim = std::max(batch_dim, seq_dim);
  for (int i = 0; i < outer_dim; ++i) {
    outer_size *= input_shape.Dims(i);
  }

  int medium_size = 1;
  for (int i = outer_dim + 1; i < medium_dim; ++i) {
    medium_size *= input_shape.Dims(i);
  }

  int copy_size = 1;
  for (int i = medium_dim + 1; i < input_shape.DimensionsCount(); ++i) {
    copy_size *= input_shape.Dims(i);
  }

  const int dims_at_outer_dim = input_shape.Dims(outer_dim);
  const int dims_at_medium_dim = input_shape.Dims(medium_dim);

  Scalar* output_ptr;
  if (batch_dim > seq_dim) {
    for (int i = 0; i < outer_size; ++i) {
      for (int j = 0; j < dims_at_outer_dim; ++j) {
        const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size;
        for (int p = 0; p < medium_size; ++p) {
          for (int q = 0; q < dims_at_medium_dim; ++q) {
            const int in_pos =
                ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size;
            const Scalar* in_ptr = input_data + in_pos;
            int sl = seq_lengths[q] - 1;
            if (j > sl) {
              output_ptr = output_data + in_pos;
            } else {
              const int out_pos_base =
                  (i * dims_at_outer_dim + sl - j) * medium_size;
              const int out_pos =
                  ((out_pos_base + p) * dims_at_medium_dim + q) * copy_size;
              output_ptr = output_data + out_pos;
            }
            memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar));
          }
        }
      }
    }
  } else if (batch_dim < seq_dim) {
    for (int i = 0; i < outer_size; ++i) {
      for (int j = 0; j < dims_at_outer_dim; ++j) {
        const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size;
        int sl = seq_lengths[j] - 1;
        const int out_pos_base = (i * dims_at_outer_dim + j) * medium_size;
        for (int p = 0; p < medium_size; ++p) {
          for (int q = 0; q < dims_at_medium_dim; ++q) {
            const int in_pos =
                ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size;
            const Scalar* in_ptr = input_data + in_pos;
            if (q > sl) {
              output_ptr = output_data + in_pos;
            } else {
              const int out_pos =
                  ((out_pos_base + p) * dims_at_medium_dim + sl - q) *
                  copy_size;
              output_ptr = output_data + out_pos;
            }
            memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar));
          }
        }
      }
    }
  }
}

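// Sums the rows of `input_data` (taken along dimension 0) into the output row
// given by the corresponding entry of `segment_ids_data`.
//
// For example: input of shape [3, 2] = {1, 2, 3, 4, 5, 6} with
// segment_ids = {0, 0, 1} and an output of shape [2, 2] produces
// {4, 6, 5, 6}.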
template <typename T>
inline void SegmentSum(const RuntimeShape& input_shape, const T* input_data,
                       const RuntimeShape& segment_ids_shape,
                       const int32_t* segment_ids_data,
                       const RuntimeShape& output_shape, T* output_data) {
  const int segment_flat_size =
      MatchingFlatSizeSkipDim(input_shape, 0, output_shape);

  memset(output_data, 0, sizeof(T) * output_shape.FlatSize());

  for (int i = 0; i < input_shape.Dims(0); i++) {
    int output_index = segment_ids_data[i];
    for (int j = 0; j < segment_flat_size; ++j) {
      output_data[output_index * segment_flat_size + j] +=
          input_data[i * segment_flat_size + j];
    }
  }
}

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_