1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
17
18 #include <stdint.h>
19 #include <sys/types.h>
20
21 #include <algorithm>
22 #include <array>
23 #include <cmath>
24 #include <cstring>
25 #include <functional>
26 #include <limits>
27 #include <memory>
28 #include <type_traits>
29
30 #include "Eigen/Core"
31 #include "fixedpoint/fixedpoint.h"
32 #include "ruy/profiler/instrumentation.h" // from @ruy
33 #include "tensorflow/lite/c/common.h"
34 #include "tensorflow/lite/kernels/internal/common.h"
35 #include "tensorflow/lite/kernels/internal/quantization_util.h"
36 #include "tensorflow/lite/kernels/internal/reference/add.h"
37 #include "tensorflow/lite/kernels/internal/reference/add_n.h"
38 #include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
39 #include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"
40 #include "tensorflow/lite/kernels/internal/reference/batch_to_space_nd.h"
41 #include "tensorflow/lite/kernels/internal/reference/binary_function.h"
42 #include "tensorflow/lite/kernels/internal/reference/cast.h"
43 #include "tensorflow/lite/kernels/internal/reference/ceil.h"
44 #include "tensorflow/lite/kernels/internal/reference/comparisons.h"
45 #include "tensorflow/lite/kernels/internal/reference/concatenation.h"
46 #include "tensorflow/lite/kernels/internal/reference/conv.h"
47 #include "tensorflow/lite/kernels/internal/reference/depth_to_space.h"
48 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
49 #include "tensorflow/lite/kernels/internal/reference/div.h"
50 #include "tensorflow/lite/kernels/internal/reference/elu.h"
51 #include "tensorflow/lite/kernels/internal/reference/exp.h"
52 #include "tensorflow/lite/kernels/internal/reference/fill.h"
53 #include "tensorflow/lite/kernels/internal/reference/floor.h"
54 #include "tensorflow/lite/kernels/internal/reference/floor_div.h"
55 #include "tensorflow/lite/kernels/internal/reference/floor_mod.h"
56 #include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
57 #include "tensorflow/lite/kernels/internal/reference/gather.h"
58 #include "tensorflow/lite/kernels/internal/reference/hard_swish.h"
59 #include "tensorflow/lite/kernels/internal/reference/l2normalization.h"
60 #include "tensorflow/lite/kernels/internal/reference/leaky_relu.h"
61 #include "tensorflow/lite/kernels/internal/reference/logistic.h"
62 #include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
63 #include "tensorflow/lite/kernels/internal/reference/mul.h"
64 #include "tensorflow/lite/kernels/internal/reference/neg.h"
65 #include "tensorflow/lite/kernels/internal/reference/pad.h"
66 #include "tensorflow/lite/kernels/internal/reference/pooling.h"
67 #include "tensorflow/lite/kernels/internal/reference/prelu.h"
68 #include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
69 #include "tensorflow/lite/kernels/internal/reference/quantize.h"
70 #include "tensorflow/lite/kernels/internal/reference/reduce.h"
71 #include "tensorflow/lite/kernels/internal/reference/requantize.h"
72 #include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
73 #include "tensorflow/lite/kernels/internal/reference/round.h"
74 #include "tensorflow/lite/kernels/internal/reference/softmax.h"
75 #include "tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h"
76 #include "tensorflow/lite/kernels/internal/reference/space_to_depth.h"
77 #include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
78 #include "tensorflow/lite/kernels/internal/reference/string_comparisons.h"
79 #include "tensorflow/lite/kernels/internal/reference/sub.h"
80 #include "tensorflow/lite/kernels/internal/reference/tanh.h"
81 #include "tensorflow/lite/kernels/internal/reference/transpose.h"
82 #include "tensorflow/lite/kernels/internal/reference/transpose_conv.h"
83 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
84 #include "tensorflow/lite/kernels/internal/tensor.h"
85 #include "tensorflow/lite/kernels/internal/types.h"
86 namespace tflite {
87
88 namespace reference_ops {
89
90 template <typename T>
Relu(const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & output_shape,T * output_data)91 inline void Relu(const RuntimeShape& input_shape, const T* input_data,
92 const RuntimeShape& output_shape, T* output_data) {
93 const int flat_size = MatchingFlatSize(input_shape, output_shape);
94 for (int i = 0; i < flat_size; ++i) {
95 const T val = input_data[i];
96 const T lower = 0;
97 const T clamped = val < lower ? lower : val;
98 output_data[i] = clamped;
99 }
100 }
101
102 template <typename T>
Relu1(const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & output_shape,T * output_data)103 inline void Relu1(const RuntimeShape& input_shape, const T* input_data,
104 const RuntimeShape& output_shape, T* output_data) {
105 ruy::profiler::ScopeLabel label("Relu1 (not fused)");
106 const int flat_size = MatchingFlatSize(input_shape, output_shape);
107 for (int i = 0; i < flat_size; ++i) {
108 const T val = input_data[i];
109 const T upper = 1;
110 const T lower = -1;
111 const T clamped = val > upper ? upper : val < lower ? lower : val;
112 output_data[i] = clamped;
113 }
114 }
115
Relu6(const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)116 inline void Relu6(const RuntimeShape& input_shape, const float* input_data,
117 const RuntimeShape& output_shape, float* output_data) {
118 ruy::profiler::ScopeLabel label("Relu6 (not fused)");
119 const int flat_size = MatchingFlatSize(input_shape, output_shape);
120 for (int i = 0; i < flat_size; ++i) {
121 const float val = input_data[i];
122 const float upper = 6;
123 const float lower = 0;
124 const float clamped = val > upper ? upper : val < lower ? lower : val;
125 output_data[i] = clamped;
126 }
127 }
128
129 template <typename T>
ReluX(const tflite::ReluParams & params,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & output_shape,T * output_data)130 inline void ReluX(const tflite::ReluParams& params,
131 const RuntimeShape& input_shape, const T* input_data,
132 const RuntimeShape& output_shape, T* output_data) {
133 ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)");
134 const int flat_size = MatchingFlatSize(input_shape, output_shape);
135 for (int i = 0; i < flat_size; ++i) {
136 const int32 val = static_cast<int32_t>(input_data[i]);
137 int32 clamped = params.output_offset +
138 MultiplyByQuantizedMultiplier(val - params.input_offset,
139 params.output_multiplier,
140 params.output_shift);
141 clamped = std::max(params.quantized_activation_min, clamped);
142 clamped = std::min(params.quantized_activation_max, clamped);
143 output_data[i] = static_cast<T>(clamped);
144 }
145 }
146
147 template <typename T>
ReluX(const tflite::ActivationParams & params,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & output_shape,T * output_data)148 inline void ReluX(const tflite::ActivationParams& params,
149 const RuntimeShape& input_shape, const T* input_data,
150 const RuntimeShape& output_shape, T* output_data) {
151 ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)");
152 const int flat_size = MatchingFlatSize(input_shape, output_shape);
153 const T max_value = params.quantized_activation_max;
154 const T min_value = params.quantized_activation_min;
155 for (int i = 0; i < flat_size; ++i) {
156 const T val = input_data[i];
157 const T clamped = val > max_value ? max_value
158 : val < min_value ? min_value
159 : val;
160 output_data[i] = clamped;
161 }
162 }
163
164 // TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
165 // dimensionality if the runtime code does a single loop over one dimension
166 // that handles broadcasting as the base case. The code generator would then
167 // generate max(D1, D2) nested for loops.
BroadcastMulFivefold(const ArithmeticParams & unswitched_params,const RuntimeShape & unswitched_input1_shape,const uint8 * unswitched_input1_data,const RuntimeShape & unswitched_input2_shape,const uint8 * unswitched_input2_data,const RuntimeShape & output_shape,uint8 * output_data)168 inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
169 const RuntimeShape& unswitched_input1_shape,
170 const uint8* unswitched_input1_data,
171 const RuntimeShape& unswitched_input2_shape,
172 const uint8* unswitched_input2_data,
173 const RuntimeShape& output_shape,
174 uint8* output_data) {
175 ArithmeticParams switched_params = unswitched_params;
176 switched_params.input1_offset = unswitched_params.input2_offset;
177 switched_params.input2_offset = unswitched_params.input1_offset;
178
179 const bool use_unswitched =
180 unswitched_params.broadcast_category ==
181 tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
182
183 const ArithmeticParams& params =
184 use_unswitched ? unswitched_params : switched_params;
185 const uint8* input1_data =
186 use_unswitched ? unswitched_input1_data : unswitched_input2_data;
187 const uint8* input2_data =
188 use_unswitched ? unswitched_input2_data : unswitched_input1_data;
189
190 // Fivefold nested loops. The second input resets its position for each
191 // iteration of the second loop. The first input resets its position at the
192 // beginning of the fourth loop. The innermost loop is an elementwise Mul of
193 // sections of the arrays.
194 uint8* output_data_ptr = output_data;
195 const uint8* input1_data_ptr = input1_data;
196 const uint8* input2_data_reset = input2_data;
197 int y0 = params.broadcast_shape[0];
198 int y1 = params.broadcast_shape[1];
199 int y2 = params.broadcast_shape[2];
200 int y3 = params.broadcast_shape[3];
201 int y4 = params.broadcast_shape[4];
202 for (int i0 = 0; i0 < y0; ++i0) {
203 const uint8* input2_data_ptr;
204 for (int i1 = 0; i1 < y1; ++i1) {
205 input2_data_ptr = input2_data_reset;
206 for (int i2 = 0; i2 < y2; ++i2) {
207 for (int i3 = 0; i3 < y3; ++i3) {
208 MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
209 output_data_ptr);
210 input2_data_ptr += y4;
211 output_data_ptr += y4;
212 }
213 input1_data_ptr += y4;
214 }
215 }
216 input2_data_reset = input2_data_ptr;
217 }
218 }
219
Mul(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int16 * input1_data,const RuntimeShape & input2_shape,const int16 * input2_data,const RuntimeShape & output_shape,int16 * output_data)220 inline void Mul(const ArithmeticParams& params,
221 const RuntimeShape& input1_shape, const int16* input1_data,
222 const RuntimeShape& input2_shape, const int16* input2_data,
223 const RuntimeShape& output_shape, int16* output_data) {
224 ruy::profiler::ScopeLabel label("Mul/Int16");
225
226 const int flat_size =
227 MatchingElementsSize(input1_shape, input2_shape, output_shape);
228
229 for (int i = 0; i < flat_size; i++) {
230 // F0 uses 0 integer bits, range [-1, 1].
231 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
232
233 F0 unclamped_result =
234 F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
235 output_data[i] = unclamped_result.raw();
236 }
237 }
238
Mul(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int16 * input1_data,const RuntimeShape & input2_shape,const int16 * input2_data,const RuntimeShape & output_shape,uint8 * output_data)239 inline void Mul(const ArithmeticParams& params,
240 const RuntimeShape& input1_shape, const int16* input1_data,
241 const RuntimeShape& input2_shape, const int16* input2_data,
242 const RuntimeShape& output_shape, uint8* output_data) {
243 ruy::profiler::ScopeLabel label("Mul/Int16Uint8");
244 int32 output_offset = params.output_offset;
245 int32 output_activation_min = params.quantized_activation_min;
246 int32 output_activation_max = params.quantized_activation_max;
247 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
248
249 const int flat_size =
250 MatchingElementsSize(input1_shape, input2_shape, output_shape);
251
252 for (int i = 0; i < flat_size; i++) {
253 // F0 uses 0 integer bits, range [-1, 1].
254 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
255
256 F0 unclamped_result =
257 F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
258 int16 rescaled_result =
259 gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
260 int16 clamped_result =
261 std::min<int16>(output_activation_max - output_offset, rescaled_result);
262 clamped_result =
263 std::max<int16>(output_activation_min - output_offset, clamped_result);
264 output_data[i] = output_offset + clamped_result;
265 }
266 }
267
Sub16(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int16_t * input1_data,const RuntimeShape & input2_shape,const int16_t * input2_data,const RuntimeShape & output_shape,int16_t * output_data)268 inline void Sub16(const ArithmeticParams& params,
269 const RuntimeShape& input1_shape, const int16_t* input1_data,
270 const RuntimeShape& input2_shape, const int16_t* input2_data,
271 const RuntimeShape& output_shape, int16_t* output_data) {
272 ruy::profiler::ScopeLabel label("Sub/Int16");
273 const int input1_shift = params.input1_shift;
274 const int flat_size =
275 MatchingElementsSize(input1_shape, input2_shape, output_shape);
276 const int16 output_activation_min = params.quantized_activation_min;
277 const int16 output_activation_max = params.quantized_activation_max;
278
279 TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
280 TFLITE_DCHECK_LE(input1_shift, 0);
281 TFLITE_DCHECK_LE(params.input2_shift, 0);
282 const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
283 const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
284 const int input_right_shift =
285 input1_shift == 0 ? -params.input2_shift : -input1_shift;
286
287 if (input1_shift == 0) {
288 // F0 uses 0 integer bits, range [-1, 1].
289 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
290 for (int i = 0; i < flat_size; ++i) {
291 F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
292 F0 scaled_input = F0::FromRaw(
293 gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
294 F0 result = SaturatingSub(input_ready_scaled, scaled_input);
295 const int16 raw_output = result.raw();
296 const int16 clamped_output = std::min(
297 output_activation_max, std::max(output_activation_min, raw_output));
298 output_data[i] = clamped_output;
299 }
300 } else {
301 // F0 uses 0 integer bits, range [-1, 1].
302 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
303 for (int i = 0; i < flat_size; ++i) {
304 F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
305 F0 scaled_input = F0::FromRaw(
306 gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
307 F0 result = SaturatingSub(scaled_input, input_ready_scaled);
308 const int16 raw_output = result.raw();
309 const int16 clamped_output = std::min(
310 output_activation_max, std::max(output_activation_min, raw_output));
311 output_data[i] = clamped_output;
312 }
313 }
314 }
315
316 template <typename Scalar>
Pack(const PackParams & params,const RuntimeShape * const * input_shapes,const Scalar * const * input_data,const RuntimeShape & output_shape,Scalar * output_data)317 void Pack(const PackParams& params, const RuntimeShape* const* input_shapes,
318 const Scalar* const* input_data, const RuntimeShape& output_shape,
319 Scalar* output_data) {
320 ruy::profiler::ScopeLabel label("Pack");
321 const int dimensions = output_shape.DimensionsCount();
322 int axis = params.axis;
323 int inputs_count = params.inputs_count;
324
325 int outer_size = 1;
326 for (int i = 0; i < axis; i++) {
327 outer_size *= output_shape.Dims(i);
328 }
329 int copy_size = 1;
330 for (int i = params.axis + 1; i < dimensions; i++) {
331 copy_size *= output_shape.Dims(i);
332 }
333 TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);
334
335 for (int i = 0; i < inputs_count; ++i) {
336 for (int k = 0; k < outer_size; k++) {
337 const Scalar* input_ptr = input_data[i] + copy_size * k;
338 int loc = k * inputs_count * copy_size + i * copy_size;
339 memcpy(output_data + loc, input_ptr, copy_size * sizeof(Scalar));
340 }
341 }
342 }
343
344 template <typename Scalar>
Unpack(const UnpackParams & params,const RuntimeShape & input_shape,const Scalar * input_data,const RuntimeShape & output_shape,Scalar * const * output_datas)345 void Unpack(const UnpackParams& params, const RuntimeShape& input_shape,
346 const Scalar* input_data, const RuntimeShape& output_shape,
347 Scalar* const* output_datas) {
348 ruy::profiler::ScopeLabel label("Unpack");
349 const int dimensions = input_shape.DimensionsCount();
350 const int outputs_count = params.num_split;
351
352 int outer_size = 1;
353 int axis = params.axis;
354 if (axis < 0) {
355 axis += dimensions;
356 }
357 TFLITE_DCHECK_GE(axis, 0);
358 TFLITE_DCHECK_LT(axis, dimensions);
359 for (int i = 0; i < axis; ++i) {
360 outer_size *= input_shape.Dims(i);
361 }
362 int copy_size = 1;
363 for (int i = axis + 1; i < dimensions; ++i) {
364 copy_size *= input_shape.Dims(i);
365 }
366 TFLITE_DCHECK_EQ(output_shape.FlatSize(), copy_size * outer_size);
367
368 for (int i = 0; i < outputs_count; ++i) {
369 for (int k = 0; k < outer_size; k++) {
370 Scalar* output_ptr = output_datas[i] + copy_size * k;
371 int loc = k * outputs_count * copy_size + i * copy_size;
372 memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
373 }
374 }
375 }
376
377 template <typename Scalar>
PackWithScaling(const PackParams & params,const RuntimeShape * const * input_shapes,const uint8 * const * input_data,const RuntimeShape & output_shape,uint8 * output_data)378 void PackWithScaling(const PackParams& params,
379 const RuntimeShape* const* input_shapes,
380 const uint8* const* input_data,
381 const RuntimeShape& output_shape, uint8* output_data) {
382 ruy::profiler::ScopeLabel label("PackWithScaling");
383 const int dimensions = output_shape.DimensionsCount();
384 int axis = params.axis;
385 const int32* input_zeropoint = params.input_zeropoint;
386 const float* input_scale = params.input_scale;
387 int inputs_count = params.inputs_count;
388 const int32 output_zeropoint = params.output_zeropoint;
389 const float output_scale = params.output_scale;
390
391 int outer_size = 1;
392 for (int i = 0; i < axis; i++) {
393 outer_size *= output_shape.Dims(i);
394 }
395 int copy_size = 1;
396 for (int i = axis + 1; i < dimensions; i++) {
397 copy_size *= output_shape.Dims(i);
398 }
399 TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);
400
401 Scalar* output_ptr = output_data;
402 const float inverse_output_scale = 1.f / output_scale;
403 for (int k = 0; k < outer_size; k++) {
404 for (int i = 0; i < inputs_count; ++i) {
405 if (input_zeropoint[i] == output_zeropoint &&
406 input_scale[i] == output_scale) {
407 memcpy(output_ptr, input_data[i] + k * copy_size,
408 copy_size * sizeof(Scalar));
409 } else {
410 assert(false);
411 const float scale = input_scale[i] * inverse_output_scale;
412 const float bias = -input_zeropoint[i] * scale;
413 auto input_ptr = input_data[i];
414 for (int j = 0; j < copy_size; ++j) {
415 const int32_t value =
416 static_cast<int32_t>(std::round(input_ptr[j] * scale + bias)) +
417 output_zeropoint;
418 output_ptr[j] =
419 static_cast<uint8_t>(std::max(std::min(255, value), 0));
420 }
421 }
422 output_ptr += copy_size;
423 }
424 }
425 }
426
427 template <typename Scalar>
DepthConcatenation(const ConcatenationParams & params,const RuntimeShape * const * input_shapes,const Scalar * const * input_data,const RuntimeShape & output_shape,Scalar * output_data)428 void DepthConcatenation(const ConcatenationParams& params,
429 const RuntimeShape* const* input_shapes,
430 const Scalar* const* input_data,
431 const RuntimeShape& output_shape, Scalar* output_data) {
432 ruy::profiler::ScopeLabel label("DepthConcatenation");
433 auto params_copy = params;
434 params_copy.axis = 3;
435 Concatenation(params_copy, input_shapes, input_data, output_shape,
436 output_data);
437 }
438
LstmCell(const LstmCellParams & params,const RuntimeShape & unextended_input_shape,const float * input_data,const RuntimeShape & unextended_prev_activ_shape,const float * prev_activ_data,const RuntimeShape & weights_shape,const float * weights_data,const RuntimeShape & unextended_bias_shape,const float * bias_data,const RuntimeShape & unextended_prev_state_shape,const float * prev_state_data,const RuntimeShape & unextended_output_state_shape,float * output_state_data,const RuntimeShape & unextended_output_activ_shape,float * output_activ_data,const RuntimeShape & unextended_concat_temp_shape,float * concat_temp_data,const RuntimeShape & unextended_activ_temp_shape,float * activ_temp_data)439 inline void LstmCell(
440 const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
441 const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
442 const float* prev_activ_data, const RuntimeShape& weights_shape,
443 const float* weights_data, const RuntimeShape& unextended_bias_shape,
444 const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
445 const float* prev_state_data,
446 const RuntimeShape& unextended_output_state_shape, float* output_state_data,
447 const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
448 const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
449 const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
450 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
451 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
452 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
453 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
454 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
455 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
456 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
457 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
458 const RuntimeShape input_shape =
459 RuntimeShape::ExtendedShape(4, unextended_input_shape);
460 const RuntimeShape prev_activ_shape =
461 RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
462 const RuntimeShape bias_shape =
463 RuntimeShape::ExtendedShape(4, unextended_bias_shape);
464 const RuntimeShape prev_state_shape =
465 RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
466 const RuntimeShape output_state_shape =
467 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
468 const RuntimeShape output_activ_shape =
469 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
470 const RuntimeShape concat_temp_shape =
471 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
472 const RuntimeShape activ_temp_shape =
473 RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
474 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
475
476 const int weights_dim_count = weights_shape.DimensionsCount();
477 const int batches =
478 MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
479 output_state_shape, 0, output_activ_shape, 0);
480 const int height =
481 MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
482 output_state_shape, 1, output_activ_shape, 1);
483 const int width =
484 MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
485 output_state_shape, 2, output_activ_shape, 2);
486 const int input_depth = input_shape.Dims(3);
487 const int prev_activ_depth = prev_activ_shape.Dims(3);
488 const int total_input_depth = prev_activ_depth + input_depth;
489 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
490 total_input_depth);
491 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
492 const int intern_activ_depth =
493 MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
494 TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
495 intern_activ_depth * total_input_depth);
496 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
497 const int output_depth =
498 MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
499 3, output_activ_shape, 3);
500 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
501
502 // Concatenate prev_activ and input data together
503 std::vector<float const*> concat_input_arrays_data;
504 std::vector<RuntimeShape const*> concat_input_arrays_shapes;
505 concat_input_arrays_data.push_back(input_data);
506 concat_input_arrays_data.push_back(prev_activ_data);
507 concat_input_arrays_shapes.push_back(&input_shape);
508 concat_input_arrays_shapes.push_back(&prev_activ_shape);
509 tflite::ConcatenationParams concat_params;
510 concat_params.axis = 3;
511 concat_params.inputs_count = concat_input_arrays_data.size();
512 Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
513 &(concat_input_arrays_data[0]), concat_temp_shape,
514 concat_temp_data);
515
516 // Fully connected
517 tflite::FullyConnectedParams fc_params;
518 fc_params.float_activation_min = std::numeric_limits<float>::lowest();
519 fc_params.float_activation_max = std::numeric_limits<float>::max();
520 FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
521 weights_data, bias_shape, bias_data, activ_temp_shape,
522 activ_temp_data);
523
524 // Memory state update (the LSTM "guts")
525 for (int b = 0; b < batches; ++b) {
526 for (int w = 0; w < width; ++w) {
527 for (int h = 0; h < height; ++h) {
528 for (int c = 0; c < output_depth; ++c) {
529 const float input_gate =
530 1.f /
531 (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
532 0 * output_depth + c)]));
533 const float new_input = std::tanh(activ_temp_data[Offset(
534 activ_temp_shape, b, h, w, 1 * output_depth + c)]);
535 const float forget_gate =
536 1.f /
537 (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
538 2 * output_depth + c)]));
539 const float output_gate =
540 1.f /
541 (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
542 3 * output_depth + c)]));
543 const float new_state =
544 input_gate * new_input +
545 forget_gate *
546 prev_state_data[Offset(prev_state_shape, b, h, w, c)];
547 output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
548 output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
549 output_gate * std::tanh(new_state);
550 }
551 }
552 }
553 }
554 }
555
556 // Quantized LSTM cell implementation.
557 // The quantization of the input, output arrays is as follows:
558 // - The input activations are quantized as uint8 on the interval
559 // [-1, 127/128].
560 // The rationale for that is that is the natural interval for output
561 // activations (see next point) and these need to be concatenated together.
562 // We could accommodate different ranges by re-scaling, but we empirically
563 // found that setting the input activations range to be [-1, 127/128] in the
564 // first place, removing the need for re-scaling, greatly improves accuracy.
565 // - The output activations are quantized as uint8 on the interval
566 // [-1, 127/128].
567 // The rationale for that is that the definition of a LSTM cell makes them
568 // intrinsically constrained in [-1, 1]; tweaking that to [-1, 127/128]
569 // makes for simpler, more accurate fixed-point arithmetic.
570 // - The output-at-previous-timestep state array is obviously quantized as
571 // the output activations.
572 // - The internal LSTM memory (not the output-at-previous-timestep, the other
573 // internal state array) is int16-quantized and may use any power-of-two,
574 // symmetric range i.e. [-2^N, 2^N * 32767/32768] for any N, which we call
575 // StateIntegerBits below, see the below discussion of that template
576 // parameter ("The StateIntegerBits template parameter").
577 // - The output of the internal fully-connected node is int16-quantized
578 // on the interval [-8, 8 * 32767/32768], the rationale for which is
579 // explained just below ("Why [-8, 8] for fully-connected output?").
580 //
581 //
582 // === The StateIntegerBits template parameter ===
583 //
584 // The StateIntegerBits template parameter controls the fixed-point format used
585 // to represent the internal memory of the LSTM cell (not the
586 // output-at-previous-timestep, the other internal state array). It's currently
587 // a template parameter so that the model can control that. The most typical
588 // value for StateIntegerBits is 4. Other plausible values are anywhere between
589 // 3 and 5. We might eventually standardize on a single supported value, e.g. 4,
590 // and drop that template parameter. The reason why it can't be a runtime
591 // parameter is that this controls the fixed-point format used, i.e. we need to
592 // generate actually different code based on it. In particular, we generate code
593 // for a fixed-point tanh() implementation for that format, which internally
594 // uses a fixed-point exp() implementation, which internally uses a
595 // barrel-shifter with a number of steps that depends on StateIntegerBits.
596 // Another consequence of that is that a higher value of StateIntegerBits
597 // results in a more expensive implementation (more barrel shifter steps
598 // needed).
599 //
600 //
601 // === Why [-8, 8] for fully-connected output? ===
602 //
603 // This array is only fed to Logistic and Tanh functions, for which
604 // the quantized implementation will want to use fixed-point arithmetic,
605 // requiring a power-of-two representation interval. Thus, we should right
606 // away quantize this array to a power-of-two interval; otherwise,
607 // implementation will need to rescale that, losing any benefit that a tighter
608 // representation interval might otherwise yield, while introducing some
609 // numerical error and computational overhead.
610 //
611 // Now, Logistic and Tanh
612 // are nearly constant (nearly equal to their horizontal asymptotes)
613 // outside of a small bounded interval around 0:
614 //
615 // Logistic(4) = 1 - 1.8e-2 Tanh(4) = 1 - 6.7e-4
616 // Logistic(8) = 1 - 3.4e-4 Tanh(8) = 1 - 2.3e-7
617 // Logistic(16) = 1 - 1.1e-7 Tanh(16) = 1 - 2.5e-14
618 //
619 // From this, we see that clamping to [-4, 4] would be too inaccurate
620 // (the error of 1.8e-2 on Logistic would be felt even in 8bit precision)
621 // while clamping to [-16, 16] would make no difference even in float32.
622 // However, for a fixed-point implementation in 16-bit integers, using 5
623 // integer bits to represent the [-16, 16] range would leave only 11
624 // fractional bits, giving an increment of 2^-11 = 4.9e-4 between consecutive
625 // representable values. Notice that is higher than the
626 // worst-case clamping error with clamping to [-8, 8]: 3.4e-4 for Logistic.
627 // Using [-8, 8] thus seems like the better compromise overall, enjoying
628 // an increment of 2.4e-4 between representable values and a worst-case
629 // clamping error of 3.4e-4, both better than the increment of 4.9e-4 with
630 // [-16, 16].
631 //
632 // Moreover, all other things being equal, it is nice to choose the narrower
633 // representation range, as that makes the implementation of fixed-point
634 // math functions a little cheaper (each integer bit requires an additional
635 // barrel-shifter atep in the implementation of exp(-x)). That is further
636 // reason to prefer [-8, 8] over [-16, 16]. The choice of [-16, 16] would make
637 // sense for 32-bit float or 32-bit fixed-point quantization, but we are
638 // aiming for 16-bit fixed-point quantization of these internal nodes here.
639 //
640 template <int StateIntegerBits>
LstmCell(const LstmCellParams & params,const RuntimeShape & unextended_input_shape,const uint8 * input_data_uint8,const RuntimeShape & unextended_prev_activ_shape,const uint8 * prev_activ_data_uint8,const RuntimeShape & weights_shape,const uint8 * weights_data_uint8,const RuntimeShape & unextended_bias_shape,const int32 * bias_data_int32,const RuntimeShape & unextended_prev_state_shape,const int16 * prev_state_data_int16,const RuntimeShape & unextended_output_state_shape,int16 * output_state_data_int16,const RuntimeShape & unextended_output_activ_shape,uint8 * output_activ_data_uint8,const RuntimeShape & unextended_concat_temp_shape,uint8 * concat_temp_data_uint8,const RuntimeShape & unextended_activ_temp_shape,int16 * activ_temp_data_int16,void * gemmlowp_context)641 inline void LstmCell(const LstmCellParams& params,
642 const RuntimeShape& unextended_input_shape,
643 const uint8* input_data_uint8,
644 const RuntimeShape& unextended_prev_activ_shape,
645 const uint8* prev_activ_data_uint8,
646 const RuntimeShape& weights_shape,
647 const uint8* weights_data_uint8,
648 const RuntimeShape& unextended_bias_shape,
649 const int32* bias_data_int32,
650 const RuntimeShape& unextended_prev_state_shape,
651 const int16* prev_state_data_int16,
652 const RuntimeShape& unextended_output_state_shape,
653 int16* output_state_data_int16,
654 const RuntimeShape& unextended_output_activ_shape,
655 uint8* output_activ_data_uint8,
656 const RuntimeShape& unextended_concat_temp_shape,
657 uint8* concat_temp_data_uint8,
658 const RuntimeShape& unextended_activ_temp_shape,
659 int16* activ_temp_data_int16, void* gemmlowp_context) {
660 (void)gemmlowp_context; // only used in optimized code.
661 int32 weights_zero_point = params.weights_zero_point;
662 int32 accum_multiplier = params.accum_multiplier;
663 int accum_shift = params.accum_shift;
664 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
665 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
666 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
667 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
668 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
669 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
670 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
671 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
672 const RuntimeShape input_shape =
673 RuntimeShape::ExtendedShape(4, unextended_input_shape);
674 const RuntimeShape prev_activ_shape =
675 RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
676 const RuntimeShape bias_shape =
677 RuntimeShape::ExtendedShape(4, unextended_bias_shape);
678 const RuntimeShape prev_state_shape =
679 RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
680 const RuntimeShape output_state_shape =
681 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
682 const RuntimeShape output_activ_shape =
683 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
684 const RuntimeShape concat_temp_shape =
685 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
686 const RuntimeShape activ_temp_shape =
687 RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
688 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
689
690 // Gather dimensions information, and perform consistency checks.
691 const int weights_dim_count = weights_shape.DimensionsCount();
692 const int outer_size = MatchingFlatSizeSkipDim(
693 input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
694 output_activ_shape);
695 const int input_depth = input_shape.Dims(3);
696 const int prev_activ_depth = prev_activ_shape.Dims(3);
697 const int total_input_depth = prev_activ_depth + input_depth;
698 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
699 total_input_depth);
700 const int intern_activ_depth =
701 MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
702 TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
703 intern_activ_depth * total_input_depth);
704 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
705 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
706 const int output_depth =
707 MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
708 3, output_activ_shape, 3);
709 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
710 const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
711 const int fc_output_depth =
712 MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
713 const int fc_accum_depth = total_input_depth;
714 TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
715
716 // Depth-concatenate prev_activ and input data together.
717 uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
718 prev_activ_data_uint8};
719 const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
720 &prev_activ_shape};
721 tflite::ConcatenationParams concat_params;
722 concat_params.axis = 3;
723 concat_params.inputs_count = 2;
724 Concatenation(concat_params, concat_input_arrays_shapes,
725 concat_input_arrays_data, concat_temp_shape,
726 concat_temp_data_uint8);
727
728 // Implementation of the fully connected node inside the LSTM cell.
729 // The operands are 8-bit integers, the accumulators are internally 32bit
730 // integers, and the output is 16-bit fixed-point with 3 integer bits so
731 // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
732 // is explained in the function comment above.
733 for (int b = 0; b < fc_batches; ++b) {
734 for (int out_c = 0; out_c < fc_output_depth; ++out_c) {
735 // Internal accumulation.
736 // Initialize accumulator with the bias-value.
737 int32 accum = bias_data_int32[out_c];
738 // Accumulation loop.
739 for (int d = 0; d < fc_accum_depth; ++d) {
740 int16 input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128;
741 int16 weights_val =
742 weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point;
743 accum += input_val * weights_val;
744 }
745 // Down-scale the final int32 accumulator to the scale used by our
746 // (16-bit, using 3 integer bits) fixed-point format. The quantized
747 // multiplier and shift here have been pre-computed offline
748 // (e.g. by toco).
749 accum =
750 MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift);
751 // Saturate, cast to int16, and store to the temporary activations array.
752 accum = std::max(-32768, std::min(32767, accum));
753 activ_temp_data_int16[out_c + fc_output_depth * b] = accum;
754 }
755 }
756
757 // Rest of the LSTM cell: tanh and logistic math functions, and some adds
758 // and muls, all done in 16-bit fixed-point.
759 for (int b = 0; b < outer_size; ++b) {
760 for (int c = 0; c < output_depth; ++c) {
761 // Define the fixed-point data types that we will use here. All use
762 // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
763 // They only differ by the number of integral vs. fractional bits,
764 // determining the range of values that they can represent.
765 //
766 // F0 uses 0 integer bits, range [-1, 1].
767 // This is the return type of math functions such as tanh, logistic,
768 // whose range is in [-1, 1].
769 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
770 // F3 uses 3 integer bits, range [-8, 8].
771 // This is the range of the previous fully-connected node's output,
772 // which is our input here.
773 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
774 // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
775 // 2^StateIntegerBits]. It's used to represent the internal state, whose
776 // number of integer bits is currently dictated by the model. See comment
777 // on the StateIntegerBits template parameter above.
778 using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
779 // Implementation of input gate, using fixed-point logistic function.
780 F3 input_gate_input = F3::FromRaw(
781 activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]);
782 F0 input_gate_output = gemmlowp::logistic(input_gate_input);
783 // Implementation of input modulation gate, using fixed-point tanh
784 // function.
785 F3 input_modulation_gate_input = F3::FromRaw(
786 activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]);
787 F0 input_modulation_gate_output =
788 gemmlowp::tanh(input_modulation_gate_input);
789 // Implementation of forget gate, using fixed-point logistic function.
790 F3 forget_gate_input = F3::FromRaw(
791 activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]);
792 F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
793 // Implementation of output gate, using fixed-point logistic function.
794 F3 output_gate_input = F3::FromRaw(
795 activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]);
796 F0 output_gate_output = gemmlowp::logistic(output_gate_input);
797 // Implementation of internal multiplication nodes, still in fixed-point.
798 F0 input_times_input_modulation =
799 input_gate_output * input_modulation_gate_output;
800 FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]);
801 FS prev_state_times_forget_state = forget_gate_output * prev_state;
802 // Implementation of internal addition node, saturating.
803 FS new_state = gemmlowp::SaturatingAdd(
804 gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
805 prev_state_times_forget_state);
806 // Implementation of last internal Tanh node, still in fixed-point.
807 // Since a Tanh fixed-point implementation is specialized for a given
808 // number or integer bits, and each specialization can have a substantial
809 // code size, and we already used above a Tanh on an input with 3 integer
810 // bits, and per the table in the above function comment there is no
811 // significant accuracy to be lost by clamping to [-8, +8] for a
812 // 3-integer-bits representation, let us just do that. This helps people
813 // porting this to targets where code footprint must be minimized.
814 F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
815 F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
816 // Store the new internal state back to memory, as 16-bit integers.
817 // Note: here we store the original value with StateIntegerBits, not
818 // the rescaled 3-integer-bits value fed to tanh.
819 output_state_data_int16[b * output_depth + c] = new_state.raw();
820 // Down-scale the output activations to 8-bit integers, saturating,
821 // and store back to memory.
822 int16 rescaled_output_activ =
823 gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
824 int16 clamped_output_activ =
825 std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
826 output_activ_data_uint8[b * output_depth + c] =
827 128 + clamped_output_activ;
828 }
829 }
830 }
831
832 template <typename Scalar>
Split(const SplitParams & params,const RuntimeShape & input_shape,const Scalar * input_data,const RuntimeShape * const * output_shapes,Scalar * const * output_data)833 void Split(const SplitParams& params, const RuntimeShape& input_shape,
834 const Scalar* input_data, const RuntimeShape* const* output_shapes,
835 Scalar* const* output_data) {
836 ruy::profiler::ScopeLabel label("Split");
837 const int split_dimensions = input_shape.DimensionsCount();
838 int axis = params.axis < 0 ? params.axis + split_dimensions : params.axis;
839 int outputs_count = params.num_split;
840 TFLITE_DCHECK_LT(axis, split_dimensions);
841
842 int64_t split_size = 0;
843 for (int i = 0; i < outputs_count; i++) {
844 TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), split_dimensions);
845 for (int j = 0; j < split_dimensions; j++) {
846 if (j != axis) {
847 MatchingDim(*output_shapes[i], j, input_shape, j);
848 }
849 }
850 split_size += output_shapes[i]->Dims(axis);
851 }
852 TFLITE_DCHECK_EQ(split_size, input_shape.Dims(axis));
853 int64_t outer_size = 1;
854 for (int i = 0; i < axis; ++i) {
855 outer_size *= input_shape.Dims(i);
856 }
857 // For all output arrays,
858 // FlatSize() = outer_size * Dims(axis) * base_inner_size;
859 int64_t base_inner_size = 1;
860 for (int i = axis + 1; i < split_dimensions; ++i) {
861 base_inner_size *= input_shape.Dims(i);
862 }
863
864 const Scalar* input_ptr = input_data;
865 for (int k = 0; k < outer_size; k++) {
866 for (int i = 0; i < outputs_count; ++i) {
867 const int copy_size = output_shapes[i]->Dims(axis) * base_inner_size;
868 memcpy(output_data[i] + k * copy_size, input_ptr,
869 copy_size * sizeof(Scalar));
870 input_ptr += copy_size;
871 }
872 }
873 }
874
NodeOffset(int b,int h,int w,int height,int width)875 inline int NodeOffset(int b, int h, int w, int height, int width) {
876 return (b * height + h) * width + w;
877 }
878
LocalResponseNormalization(const tflite::LocalResponseNormalizationParams & op_params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)879 inline void LocalResponseNormalization(
880 const tflite::LocalResponseNormalizationParams& op_params,
881 const RuntimeShape& input_shape, const float* input_data,
882 const RuntimeShape& output_shape, float* output_data) {
883 const int trailing_dim = input_shape.DimensionsCount() - 1;
884 const int outer_size =
885 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
886 const int depth =
887 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
888
889 for (int i = 0; i < outer_size; ++i) {
890 for (int c = 0; c < depth; ++c) {
891 const int begin_input_c = std::max(0, c - op_params.range);
892 const int end_input_c = std::min(depth, c + op_params.range);
893 float accum = 0.f;
894 for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) {
895 const float input_val = input_data[i * depth + input_c];
896 accum += input_val * input_val;
897 }
898 const float multiplier =
899 std::pow(op_params.bias + op_params.alpha * accum, -op_params.beta);
900 output_data[i * depth + c] = input_data[i * depth + c] * multiplier;
901 }
902 }
903 }
904
LogSoftmax(const SoftmaxParams & params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)905 inline void LogSoftmax(const SoftmaxParams& params,
906 const RuntimeShape& input_shape, const float* input_data,
907 const RuntimeShape& output_shape, float* output_data) {
908 const int trailing_dim = input_shape.DimensionsCount() - 1;
909 const int outer_size =
910 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
911 const int depth =
912 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
913
914 for (int i = 0; i < outer_size; ++i) {
915 // Find max element value which we'll use to ensure numerical stability
916 // taking advantage of the following equality:
917 // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C)))
918 float max = std::numeric_limits<float>::lowest();
919 for (int c = 0; c < depth; ++c) {
920 max = std::max(max, input_data[i * depth + c]);
921 }
922
923 // Compute sum.
924 float sum = 0.f;
925 for (int c = 0; c < depth; ++c) {
926 sum += std::exp(input_data[i * depth + c] - max);
927 }
928
929 // Compute result.
930 const float log_sum = std::log(sum);
931 for (int c = 0; c < depth; ++c) {
932 output_data[i * depth + c] = input_data[i * depth + c] - max - log_sum;
933 }
934 }
935 }
936
LogSoftmax(const SoftmaxParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & output_shape,uint8 * output_data)937 inline void LogSoftmax(const SoftmaxParams& params,
938 const RuntimeShape& input_shape, const uint8* input_data,
939 const RuntimeShape& output_shape, uint8* output_data) {
940 ruy::profiler::ScopeLabel label("LogSoftmax/8bit");
941 const int32 input_multiplier = params.input_multiplier;
942 const int32 input_left_shift = params.input_left_shift;
943 const int32 reverse_scaling_divisor = params.reverse_scaling_divisor;
944 const int32 reverse_scaling_right_shift = params.reverse_scaling_right_shift;
945 const int diff_min = params.diff_min;
946 // The representation chosen for the input to the exp() function is Q5.26.
947 // We need to leave extra space since values that we skip might be as large
948 // as -32 before multiplying by input_beta_multiplier, and therefore as
949 // large as -16 afterwards. Note that exp(-8) is definitely not
950 // insignificant to accumulation, but exp(-16) definitely is.
951 static constexpr int kScaledDiffIntegerBits = 5;
952 static constexpr int kAccumulationIntegerBits = 12;
953 static constexpr int kOutputIntegerBits = 4;
954 using FixedPointScaledDiff =
955 gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
956 using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
957
958 const int trailing_dim = input_shape.DimensionsCount() - 1;
959 const int outer_size =
960 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
961 const int depth =
962 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
963
964 for (int i = 0; i < outer_size; ++i) {
965 uint8 max_in_row = 0;
966 for (int c = 0; c < depth; ++c) {
967 max_in_row = std::max(max_in_row, input_data[i * depth + c]);
968 }
969
970 FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
971 for (int c = 0; c < depth; ++c) {
972 int32 input_diff =
973 static_cast<int32>(input_data[i * depth + c]) - max_in_row;
974 if (input_diff >= diff_min) {
975 const int32 input_diff_rescaled =
976 MultiplyByQuantizedMultiplierGreaterThanOne(
977 input_diff, input_multiplier, input_left_shift);
978 const FixedPointScaledDiff scaled_diff_f8 =
979 FixedPointScaledDiff::FromRaw(input_diff_rescaled);
980 sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
981 exp_on_negative_values(scaled_diff_f8));
982 }
983 }
984
985 const int32 fixed_log_sum_of_exps =
986 log_x_for_x_greater_than_or_equal_to_1<kScaledDiffIntegerBits>(
987 sum_of_exps)
988 .raw();
989
990 // rescaled_diff_min is smallest representable in
991 // Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the
992 // log-sub-exps that will be subtracted in the loop.
993 //
994 // The thresholds diff_min, etc are negative.
995 const int rescaled_diff_min =
996 fixed_log_sum_of_exps + std::numeric_limits<int32>::lowest();
997 const int adjusted_diff_min =
998 std::max(diff_min - 1, // Note use of > below instead of >= above.
999 MultiplyByQuantizedMultiplierSmallerThanOneExp(
1000 rescaled_diff_min, reverse_scaling_divisor,
1001 -reverse_scaling_right_shift));
1002
1003 for (int c = 0; c < depth; ++c) {
1004 int32 input_diff =
1005 static_cast<int32>(input_data[i * depth + c]) - max_in_row;
1006 if (input_diff > adjusted_diff_min) {
1007 const int32 input_diff_rescaled =
1008 MultiplyByQuantizedMultiplierGreaterThanOne(
1009 input_diff, input_multiplier, input_left_shift);
1010 int32 unsat_output =
1011 gemmlowp::RoundingDivideByPOT(
1012 (input_diff_rescaled - fixed_log_sum_of_exps),
1013 31 - kScaledDiffIntegerBits - kOutputIntegerBits) +
1014 255;
1015
1016 output_data[i * depth + c] = static_cast<uint8>(
1017 std::max(std::min(unsat_output, static_cast<int32>(255)), 0));
1018 } else {
1019 // Set output to smallest value.
1020 output_data[i * depth + c] = 0;
1021 }
1022 }
1023 }
1024 }
1025
Dequantize(const RuntimeShape & input_shape,const Eigen::half * input_data,const RuntimeShape & output_shape,float * output_data)1026 inline void Dequantize(const RuntimeShape& input_shape,
1027 const Eigen::half* input_data,
1028 const RuntimeShape& output_shape, float* output_data) {
1029 const int flat_size = MatchingFlatSize(input_shape, output_shape);
1030 for (int i = 0; i < flat_size; i++) {
1031 output_data[i] = Eigen::half_impl::half_to_float(input_data[i]);
1032 }
1033 }
1034
FakeQuant(const tflite::FakeQuantParams & op_params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)1035 inline void FakeQuant(const tflite::FakeQuantParams& op_params,
1036 const RuntimeShape& input_shape, const float* input_data,
1037 const RuntimeShape& output_shape, float* output_data) {
1038 ruy::profiler::ScopeLabel label("FakeQuant");
1039 float rmin = op_params.minmax.min;
1040 float rmax = op_params.minmax.max;
1041 int num_bits = op_params.num_bits;
1042 // 0 should always be a representable value. Let's assume that the initial
1043 // min,max range contains 0.
1044 TFLITE_DCHECK_LE(rmin, 0.0f);
1045 TFLITE_DCHECK_GE(rmax, 0.0f);
1046 TFLITE_DCHECK_LT(rmin, rmax);
1047
1048 // Code matches tensorflow's FakeQuantWithMinMaxArgsFunctor.
1049 int quant_min = 0;
1050 int quant_max = (1 << num_bits) - 1;
1051 float nudged_min, nudged_max, nudged_scale;
1052 NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min,
1053 &nudged_max, &nudged_scale);
1054 const int flat_size = MatchingFlatSize(input_shape, output_shape);
1055 FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data,
1056 output_data, flat_size);
1057 }
1058
1059 // Common subroutine for both `GatherNd` and `GatherNdString`.
1060 struct GatherNdHelperResult {
1061 int n_slices;
1062 int slice_size;
1063 int indices_nd;
1064 std::vector<int> dims_to_count;
1065 };
1066
1067 // Returns common values being used on both `GatherNd` and `GatherNdString`.
GatherNdHelper(const RuntimeShape & params_shape,const RuntimeShape & indices_shape)1068 inline GatherNdHelperResult GatherNdHelper(const RuntimeShape& params_shape,
1069 const RuntimeShape& indices_shape) {
1070 GatherNdHelperResult ret;
1071 ret.n_slices = 1;
1072 ret.slice_size = 1;
1073 const int indices_dims = indices_shape.DimensionsCount();
1074 ret.indices_nd = indices_shape.Dims(indices_dims - 1);
1075 const int params_dims = params_shape.DimensionsCount();
1076 for (int i = 0; i < indices_dims - 1; ++i) {
1077 ret.n_slices *= indices_shape.Dims(i);
1078 }
1079 for (int i = ret.indices_nd; i < params_dims; ++i) {
1080 ret.slice_size *= params_shape.Dims(i);
1081 }
1082
1083 int remain_flat_size = params_shape.FlatSize();
1084 ret.dims_to_count = std::vector<int>(ret.indices_nd, 0);
1085 for (int i = 0; i < ret.indices_nd; ++i) {
1086 ret.dims_to_count[i] = remain_flat_size / params_shape.Dims(i);
1087 remain_flat_size = ret.dims_to_count[i];
1088 }
1089
1090 return ret;
1091 }
1092
1093 template <typename ParamsT, typename IndicesT = int32>
GatherNd(const RuntimeShape & params_shape,const ParamsT * params_data,const RuntimeShape & indices_shape,const IndicesT * indices_data,const RuntimeShape & output_shape,ParamsT * output_data)1094 inline void GatherNd(const RuntimeShape& params_shape,
1095 const ParamsT* params_data,
1096 const RuntimeShape& indices_shape,
1097 const IndicesT* indices_data,
1098 const RuntimeShape& output_shape, ParamsT* output_data) {
1099 ruy::profiler::ScopeLabel label("GatherNd");
1100
1101 const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
1102 for (int i = 0; i < res.n_slices; ++i) {
1103 int from_pos = 0;
1104 for (int j = 0; j < res.indices_nd; ++j) {
1105 from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
1106 }
1107 std::memcpy(output_data + i * res.slice_size, params_data + from_pos,
1108 sizeof(ParamsT) * res.slice_size);
1109 }
1110 }
1111
1112 #ifndef TF_LITE_STATIC_MEMORY
1113 template <typename IndicesT = int32>
GatherNdString(const RuntimeShape & params_shape,const TfLiteTensor * params_data,const RuntimeShape & indices_shape,const IndicesT * indices_data,const RuntimeShape & output_shape,TfLiteTensor * output_data)1114 inline void GatherNdString(const RuntimeShape& params_shape,
1115 const TfLiteTensor* params_data,
1116 const RuntimeShape& indices_shape,
1117 const IndicesT* indices_data,
1118 const RuntimeShape& output_shape,
1119 TfLiteTensor* output_data) {
1120 ruy::profiler::ScopeLabel label("GatherNdString");
1121
1122 const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
1123 DynamicBuffer buffer;
1124 for (int i = 0; i < res.n_slices; ++i) {
1125 int from_pos = 0;
1126 for (int j = 0; j < res.indices_nd; ++j) {
1127 from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
1128 }
1129 for (int j = 0; j < res.slice_size; ++j) {
1130 buffer.AddString(GetString(params_data, from_pos + j));
1131 }
1132 }
1133 buffer.WriteToTensor(output_data, /*new_shape=*/nullptr);
1134 }
1135 #endif
1136
1137 template <typename IndicesT, typename UpdatesT>
ScatterNd(const RuntimeShape & indices_shape,const IndicesT * indices_data,const RuntimeShape & updates_shape,const UpdatesT * updates_data,const RuntimeShape & output_shape,UpdatesT * output_data)1138 inline void ScatterNd(const RuntimeShape& indices_shape,
1139 const IndicesT* indices_data,
1140 const RuntimeShape& updates_shape,
1141 const UpdatesT* updates_data,
1142 const RuntimeShape& output_shape, UpdatesT* output_data) {
1143 ruy::profiler::ScopeLabel label("ScatterNd");
1144
1145 int n_slices = 1;
1146 int slice_size = 1;
1147 const int outer_dims = indices_shape.DimensionsCount() - 1;
1148 const int indices_nd = indices_shape.Dims(outer_dims);
1149 const int updates_dims = updates_shape.DimensionsCount();
1150 for (int i = 0; i < outer_dims; ++i) {
1151 n_slices *= indices_shape.Dims(i);
1152 }
1153 for (int i = outer_dims; i < updates_dims; ++i) {
1154 slice_size *= updates_shape.Dims(i);
1155 }
1156
1157 int output_flat_size = output_shape.FlatSize();
1158 int remain_flat_size = output_flat_size;
1159 std::vector<int> dims_to_count(indices_nd, 0);
1160 for (int i = 0; i < indices_nd; ++i) {
1161 dims_to_count[i] = remain_flat_size / output_shape.Dims(i);
1162 remain_flat_size = dims_to_count[i];
1163 }
1164
1165 memset(output_data, 0, sizeof(UpdatesT) * output_flat_size);
1166 for (int i = 0; i < n_slices; ++i) {
1167 int to_pos = 0;
1168 for (int j = 0; j < indices_nd; ++j) {
1169 IndicesT idx = indices_data[i * indices_nd + j];
1170 TFLITE_DCHECK(0 <= idx && idx < output_shape.Dims(j));
1171 to_pos += idx * dims_to_count[j];
1172 }
1173 for (int j = 0; j < slice_size; j++) {
1174 output_data[to_pos + j] += updates_data[i * slice_size + j];
1175 }
1176 }
1177 }
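// Worked example (hypothetical values, for illustration only): updates
// scattered to the same index accumulate, since the inner loop uses `+=` on a
// zero-initialized output.
//
//   indices (shape [3, 1]) = {1, 3, 1}
//   updates (shape [3])    = {10, 20, 30}
//   output  (shape [4])    = {0, 10 + 30, 0, 20} = {0, 40, 0, 20}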
1178
1179 inline void ComputeInterpolationValues(const float value, const float scale,
1180 const bool half_pixel_centers,
1181 int32 input_size, float* scaled_value,
1182 int32* lower_bound, int32* upper_bound) {
1183 if (half_pixel_centers) {
1184 *scaled_value = (value + 0.5f) * scale - 0.5f;
1185 } else {
1186 *scaled_value = value * scale;
1187 }
1188 float scaled_value_floor = std::floor(*scaled_value);
1189 *lower_bound =
1190 std::max(static_cast<int32>(scaled_value_floor), static_cast<int32>(0));
1191 *upper_bound =
1192 std::min(static_cast<int32>(std::ceil(*scaled_value)), input_size - 1);
1193 }
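// Numeric example (made-up sizes, for illustration only): with scale = 2.0
// (input_size = 8 mapped onto 4 output pixels) and value = 1:
//   half_pixel_centers == true:  scaled_value = (1 + 0.5) * 2 - 0.5 = 2.5,
//                                lower_bound = 2, upper_bound = 3
//   half_pixel_centers == false: scaled_value = 1 * 2 = 2.0,
//                                lower_bound = 2, upper_bound = 2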
1194
1195 template <typename T>
1196 inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
1197 const RuntimeShape& unextended_input_shape,
1198 const T* input_data,
1199 const RuntimeShape& unextended_output_size_shape,
1200 const int32* output_size_data,
1201 const RuntimeShape& unextended_output_shape,
1202 T* output_data) {
1203 // If half_pixel_centers is True, align_corners must be False.
1204 TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
1205 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1206 TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
1207 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1208 const RuntimeShape input_shape =
1209 RuntimeShape::ExtendedShape(4, unextended_input_shape);
1210 const RuntimeShape output_size_shape =
1211 RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
1212 const RuntimeShape output_shape =
1213 RuntimeShape::ExtendedShape(4, unextended_output_shape);
1214
1215 int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
1216 int32 input_height = input_shape.Dims(1);
1217 int32 input_width = input_shape.Dims(2);
1218 int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
1219
1220 TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
1221 TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
1222 TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
1223 TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
1224 int32 output_height = output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
1225 int32 output_width = output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
1226
1227 float height_scale = static_cast<float>(input_height) / output_height;
1228 float width_scale = static_cast<float>(input_width) / output_width;
1229 if (op_params.align_corners && output_height > 1) {
1230 height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
1231 }
1232 if (op_params.align_corners && output_width > 1) {
1233 width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
1234 }
1235
1236 for (int b = 0; b < batches; ++b) {
1237 for (int y = 0; y < output_height; ++y) {
1238 float input_y;
1239 int32 y0, y1;
1240 ComputeInterpolationValues(y, height_scale, op_params.half_pixel_centers,
1241 input_height, &input_y, &y0, &y1);
1242 for (int x = 0; x < output_width; ++x) {
1243 float input_x;
1244 int32 x0, x1;
1245 ComputeInterpolationValues(x, width_scale, op_params.half_pixel_centers,
1246 input_width, &input_x, &x0, &x1);
1247 for (int c = 0; c < depth; ++c) {
1248 T interpolation =
1249 static_cast<T>(input_data[Offset(input_shape, b, y0, x0, c)] *
1250 (1 - (input_y - y0)) * (1 - (input_x - x0)) +
1251 input_data[Offset(input_shape, b, y1, x0, c)] *
1252 (input_y - y0) * (1 - (input_x - x0)) +
1253 input_data[Offset(input_shape, b, y0, x1, c)] *
1254 (1 - (input_y - y0)) * (input_x - x0) +
1255 input_data[Offset(input_shape, b, y1, x1, c)] *
1256 (input_y - y0) * (input_x - x0));
1257 output_data[Offset(output_shape, b, y, x, c)] = interpolation;
1258 }
1259 }
1260 }
1261 }
1262 }
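// Interpolation weights spelled out (a worked example with a hypothetical
// fractional position): if input_y = 0.75 with y0 = 0, y1 = 1, and
// input_x = 0.25 with x0 = 0, x1 = 1, the four corner weights above are
//   (1 - 0.75) * (1 - 0.25) = 0.1875   for (y0, x0)
//   0.75       * (1 - 0.25) = 0.5625   for (y1, x0)
//   (1 - 0.75) * 0.25       = 0.0625   for (y0, x1)
//   0.75       * 0.25       = 0.1875   for (y1, x1)
// They sum to 1, so the output is a convex combination of the four
// neighboring input pixels.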
1263
1264 inline void ComputeInterpolationValues(const int32 value, const int32 scale_10,
1265 const bool half_pixel_centers,
1266 int32 input_size, int32* scaled_value,
1267 int32* lower_bound, int32* upper_bound) {
1268 if (half_pixel_centers) {
1269 *scaled_value = value * scale_10 + scale_10 / 2 - (1 << 9);
1270 } else {
1271 *scaled_value = value * scale_10;
1272 }
1273 *lower_bound = std::max(*scaled_value / (1 << 10), 0);
1274 *upper_bound =
1275 std::min((*scaled_value + (1 << 10) - 1) / (1 << 10), input_size - 1);
1276 }
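// Numeric example (hypothetical values): scale_10 is the scale in Q10 fixed
// point, i.e. scale * (1 << 10). With scale_10 = 512 (scale = 0.5), value = 1,
// half_pixel_centers == true, and input_size >= 2:
//   *scaled_value = 1 * 512 + 256 - 512 = 256        (0.25 in Q10)
//   *lower_bound  = 256 / 1024 = 0
//   *upper_bound  = min((256 + 1023) / 1024, input_size - 1) = 1
// which matches the float version: (1 + 0.5) * 0.5 - 0.5 = 0.25, with floor 0
// and ceil 1.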
1277
1278 // Same as above, but does not use any floating-point arithmetic for the resize.
1279 template <typename T>
1280 inline void ResizeBilinearInteger(
1281 const tflite::ResizeBilinearParams& op_params,
1282 const RuntimeShape& unextended_input_shape, const T* input_data,
1283 const RuntimeShape& unextended_output_size_shape,
1284 const int32* output_size_data, const RuntimeShape& unextended_output_shape,
1285 T* output_data) {
1286 // If half_pixel_centers is True, align_corners must be False.
1287 TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
1288 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
1289 TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
1290 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1291 const RuntimeShape input_shape =
1292 RuntimeShape::ExtendedShape(4, unextended_input_shape);
1293 const RuntimeShape output_size_shape =
1294 RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
1295 const RuntimeShape output_shape =
1296 RuntimeShape::ExtendedShape(4, unextended_output_shape);
1297
1298 const int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
1299 const int32 input_height = input_shape.Dims(1);
1300 const int32 input_width = input_shape.Dims(2);
1301 const int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
1302
1303 TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
1304 TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
1305 TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
1306 TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
1307 const int32 output_height =
1308 output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
1309 const int32 output_width =
1310 output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
1311
1312 int32 height_scale_10 =
1313 ((1 << 10) * input_height + output_height / 2) / output_height;
1314 int32 width_scale_10 =
1315 ((1 << 10) * input_width + output_width / 2) / output_width;
1316 if (op_params.align_corners && output_height > 1) {
1317 height_scale_10 =
1318 ((1 << 10) * (input_height - 1) + (output_height - 1) / 2) /
1319 (output_height - 1);
1320 }
1321 if (op_params.align_corners && output_width > 1) {
1322 width_scale_10 = ((1 << 10) * (input_width - 1) + (output_width - 1) / 2) /
1323 (output_width - 1);
1324 }
1325
1326 for (int b = 0; b < batches; ++b) {
1327 for (int y = 0; y < output_height; ++y) {
1328 int32 input_y, y0, y1;
1329 ComputeInterpolationValues(y, height_scale_10,
1330 op_params.half_pixel_centers, input_height,
1331 &input_y, &y0, &y1);
1332 for (int x = 0; x < output_width; ++x) {
1333 int32 input_x, x0, x1;
1334 ComputeInterpolationValues(x, width_scale_10,
1335 op_params.half_pixel_centers, input_width,
1336 &input_x, &x0, &x1);
1337 for (int c = 0; c < depth; ++c) {
1338 const int64_t output_20_ll =
1339 static_cast<int64_t>(
1340 input_data[Offset(input_shape, b, y0, x0, c)]) *
1341 ((1 << 10) - (input_y - (1 << 10) * y0)) *
1342 ((1 << 10) - (input_x - (1 << 10) * x0));
1343 const int64_t output_20_lu =
1344 static_cast<int64_t>(
1345 input_data[Offset(input_shape, b, y1, x0, c)]) *
1346 (input_y - (1 << 10) * y0) *
1347 ((1 << 10) - (input_x - (1 << 10) * x0));
1348 const int64_t output_20_rl =
1349 static_cast<int64_t>(
1350 input_data[Offset(input_shape, b, y0, x1, c)]) *
1351 ((1 << 10) - (input_y - (1 << 10) * y0)) *
1352 (input_x - (1 << 10) * x0);
1353 const int64_t output_20_ru =
1354 static_cast<int64_t>(
1355 input_data[Offset(input_shape, b, y1, x1, c)]) *
1356 (input_y - (1 << 10) * y0) * (input_x - (1 << 10) * x0);
1357 const int64_t output_20 =
1358 output_20_ll + output_20_lu + output_20_rl + output_20_ru;
1359 const int64_t round = (output_20 > 0) ? (1 << 19) : -(1 << 19);
1360 const T interpolation =
1361 static_cast<T>((output_20 + round) / (1 << 20));
1362 output_data[Offset(output_shape, b, y, x, c)] = interpolation;
1363 }
1364 }
1365 }
1366 }
1367 }
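// A short sanity check on the fixed-point arithmetic above (illustrative
// reasoning, not an exhaustive proof): each corner term multiplies an input
// pixel by two Q10 weights, so output_20 is in Q20. When the sample lands
// exactly on pixel (y0, x0), i.e. input_y == (1 << 10) * y0 and
// input_x == (1 << 10) * x0, only output_20_ll is nonzero and equals
// pixel * (1 << 20); adding the +/- (1 << 19) rounding term and dividing by
// (1 << 20) then reproduces that pixel value exactly.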
1368
1369 template <typename T>
1370 inline void Slice(const tflite::SliceParams& op_params,
1371 const RuntimeShape& input_shape,
1372 const RuntimeShape& output_shape,
1373 SequentialTensorWriter<T>* writer) {
1374 const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(5, input_shape);
1375 TFLITE_DCHECK_LE(op_params.begin_count, 5);
1376 TFLITE_DCHECK_LE(op_params.size_count, 5);
1377 const int begin_count = op_params.begin_count;
1378 const int size_count = op_params.size_count;
1379 // We front-pad the begin and size vectors.
1380 std::array<int, 5> start;
1381 std::array<int, 5> stop;
1382 for (int i = 0; i < 5; ++i) {
1383 int padded_i = 5 - i;
1384 start[i] =
1385 begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
1386 stop[i] =
1387 (size_count < padded_i || op_params.size[size_count - padded_i] == -1)
1388 ? ext_shape.Dims(i)
1389 : start[i] + op_params.size[size_count - padded_i];
1390 }
1391
1392 for (int i0 = start[0]; i0 < stop[0]; ++i0) {
1393 for (int i1 = start[1]; i1 < stop[1]; ++i1) {
1394 for (int i2 = start[2]; i2 < stop[2]; ++i2) {
1395 for (int i3 = start[3]; i3 < stop[3]; ++i3) {
1396 for (int i4 = start[4]; i4 < stop[4]; ++i4) {
1397 writer->Write(Offset(ext_shape, i0, i1, i2, i3, i4));
1398 }
1399 }
1400 }
1401 }
1402 }
1403 }
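// Worked example of the front-padding above (hypothetical 3-D slice): for an
// input of shape [2, 4, 6] with begin = {0, 1, 2} and size = {2, 3, -1}, the
// extended shape is [1, 1, 2, 4, 6] and the padded vectors become
//   start = {0, 0, 0, 1, 2}
//   stop  = {1, 1, 2, 4, 6}
// i.e. the two leading (padded) dimensions are traversed in full, and a -1
// size entry expands to "through the end of that dimension".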
1404
1405 template <typename T>
1406 inline void Slice(const tflite::SliceParams& op_params,
1407 const RuntimeShape& input_shape, const T* input_data,
1408 const RuntimeShape& output_shape, T* output_data) {
1409 SequentialTensorWriter<T> writer(input_data, output_data);
1410 return Slice(op_params, input_shape, output_shape, &writer);
1411 }
1412
1413 template <typename T>
1414 inline void Slice(const tflite::SliceParams& op_params,
1415 const RuntimeShape& input_shape, const TfLiteTensor* input,
1416 const RuntimeShape& output_shape, TfLiteTensor* output) {
1417 SequentialTensorWriter<T> writer(input, output);
1418 return Slice(op_params, input_shape, output_shape, &writer);
1419 }
1420
1421 template <typename T>
1422 void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
1423 const T* input2_data, const RuntimeShape& output_shape,
1424 T* output_data) {
1425 const int flat_size = MatchingFlatSize(input1_shape, output_shape);
1426
1427 auto min_value = input2_data[0];
1428 for (int i = 0; i < flat_size; i++) {
1429 output_data[i] = input1_data[i] > min_value ? min_value : input1_data[i];
1430 }
1431 }
1432
1433 // Convenience version that allows, for example, generated-code calls to be
1434 // the same as other binary ops.
1435 template <typename T>
1436 inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
1437 const RuntimeShape&, const T* input2_data,
1438 const RuntimeShape& output_shape, T* output_data) {
1439 // Drop shape of second input: not needed.
1440 Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
1441 }
1442
1443 template <typename T>
1444 void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
1445 const T* input2_data, const RuntimeShape& output_shape,
1446 T* output_data) {
1447 const int flat_size = MatchingFlatSize(input1_shape, output_shape);
1448
1449 auto max_value = input2_data[0];
1450 for (int i = 0; i < flat_size; i++) {
1451 output_data[i] = input1_data[i] < max_value ? max_value : input1_data[i];
1452 }
1453 }
1454
1455 // Convenience version that allows, for example, generated-code calls to be
1456 // the same as other binary ops.
1457 template <typename T>
1458 inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
1459 const RuntimeShape&, const T* input2_data,
1460 const RuntimeShape& output_shape, T* output_data) {
1461 // Drop shape of second input: not needed.
1462 Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
1463 }
1464
1465 template <typename T1, typename T2, typename T3>
1466 void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
1467 const T3* input2_data, const RuntimeShape& output_shape,
1468 T2* output_data) {
1469 ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
1470 std::greater<T1>());
1471 }
1472
1473 // Convenience version that allows, for example, generated-code calls to be
1474 // the same as other binary ops.
1475 template <typename T1, typename T2, typename T3>
1476 inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
1477 const RuntimeShape& input2_shape, const T3* input2_data,
1478 const RuntimeShape& output_shape, T2* output_data) {
1479 // Drop shape of second input: not needed.
1480 ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
1481 }
1482
1483 template <typename D, typename T>
1484 void Select(const RuntimeShape& input_condition_shape,
1485 const D* input_condition_data, const RuntimeShape& input_x_shape,
1486 const T* input_x_data, const RuntimeShape& input_y_shape,
1487 const T* input_y_data, const RuntimeShape& output_shape,
1488 T* output_data) {
1489 const int64_t flatsize = MatchingFlatSize(
1490 input_condition_shape, input_x_shape, input_y_shape, output_shape);
1491 for (int64_t i = 0; i < flatsize; ++i) {
1492 output_data[i] =
1493 input_condition_data[i] ? input_x_data[i] : input_y_data[i];
1494 }
1495 }
1496
1497 template <typename D, typename T>
1498 void RankOneSelect(const RuntimeShape& input_condition_shape,
1499 const D* input_condition_data,
1500 const RuntimeShape& input_x_shape, const T* input_x_data,
1501 const RuntimeShape& input_y_shape, const T* input_y_data,
1502 const RuntimeShape& output_shape, T* output_data) {
1503 const int64_t outer_size = input_condition_shape.FlatSize();
1504 int64_t inner_size;
1505 if (input_condition_shape.DimensionsCount() == 0) {
1506 inner_size = MatchingFlatSize(input_x_shape, input_y_shape, output_shape);
1507 } else {
1508 TFLITE_DCHECK_EQ(
1509 MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0),
1510 outer_size);
1511 inner_size =
1512 MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
1513 }
1514
1515 int64_t offset = 0;
1516 for (int64_t i = 0; i < outer_size; i++) {
1517 const T* input_data = input_condition_data[i] ? input_x_data : input_y_data;
1518 memcpy(output_data + offset, input_data + offset, inner_size * sizeof(T));
1519 offset += inner_size;
1520 }
1521 }
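// Worked example (hypothetical values): the condition is rank-1 and selects
// whole inner slices rather than individual elements.
//
//   condition (shape [2])    = {1, 0}
//   x         (shape [2, 3]) = {1, 1, 1, 2, 2, 2}
//   y         (shape [2, 3]) = {7, 7, 7, 8, 8, 8}
//   output    (shape [2, 3]) = {1, 1, 1, 8, 8, 8}
//
// Here outer_size == 2 and inner_size == 3, so each iteration memcpy's one
// 3-element row from either x or y.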
1522
1523 template <typename D, typename T>
1524 void BroadcastSelect4DSlow(const RuntimeShape& input_condition_shape,
1525 const D* input_condition_data,
1526 const RuntimeShape& input_x_shape,
1527 const T* input_x_data,
1528 const RuntimeShape& input_y_shape,
1529 const T* input_y_data,
1530 const RuntimeShape& output_shape, T* output_data) {
1531 TFLITE_DCHECK_LE(input_condition_shape.DimensionsCount(), 4);
1532 TFLITE_DCHECK_LE(input_x_shape.DimensionsCount(), 4);
1533 TFLITE_DCHECK_LE(input_y_shape.DimensionsCount(), 4);
1534 TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 4);
1535
1536 const RuntimeShape extended_output_shape =
1537 RuntimeShape::ExtendedShape(4, output_shape);
1538
1539 NdArrayDesc<4> desc_condition;
1540 NdArrayDesc<4> desc_x;
1541 NdArrayDesc<4> desc_y;
1542 NdArrayDescsForElementwiseBroadcast(input_condition_shape, input_x_shape,
1543 input_y_shape, &desc_condition, &desc_x,
1544 &desc_y);
1545
1546 // In TensorFlow, the dimensions are canonically named (batch_number, row,
1547 // col, channel), with extents (batches, height, width, depth), with the
1548 // trailing dimension changing most rapidly (channels has the smallest
1549 // stride, typically 1 element).
1550 //
1551 // In generated C code, we store arrays with the dimensions reversed. The
1552 // first dimension has smallest stride.
1553 //
1554 // We name our variables by their TensorFlow convention, but generate C code
1555 // nesting loops such that the innermost loop has the smallest stride for
1556 // the best cache behavior.
1557 for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
1558 for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
1559 for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
1560 for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
1561 const int condition_index =
1562 SubscriptToIndex(desc_condition, b, y, x, c);
1563 const int x_index = SubscriptToIndex(desc_x, b, y, x, c);
1564 const int y_index = SubscriptToIndex(desc_y, b, y, x, c);
1565 output_data[Offset(extended_output_shape, b, y, x, c)] =
1566 input_condition_data[condition_index] ? input_x_data[x_index]
1567 : input_y_data[y_index];
1568 }
1569 }
1570 }
1571 }
1572 }
1573
1574 template <typename D, typename T>
1575 void SelectTrueCoords(const RuntimeShape& input_condition_shape,
1576 const D* input_condition_data, T* output_data) {
1577 const size_t size = input_condition_shape.FlatSize();
1578 if (size == 0) {
1579 // The flattened condition size is zero, so there is nothing to output.
1580 return;
1581 }
1582 const size_t cond_rank = input_condition_shape.DimensionsCount();
1583
1584 std::vector<int> dims_to_count(cond_rank, 0);
1585 int cur_flat_size = size;
1586 for (int i = 0; i < cond_rank; ++i) {
1587 dims_to_count[i] = cur_flat_size / input_condition_shape.Dims(i);
1588 cur_flat_size = dims_to_count[i];
1589 }
1590
1591 int output_index = 0;
1592 for (int i = 0; i < size; ++i) {
1593 if (input_condition_data[i]) {
1594 // Insert the row-major coordinates of the current item into the output.
1595 int flat_index = i;
1596 for (int j = 0; j < cond_rank; ++j) {
1597 int coord_j = flat_index / dims_to_count[j];
1598 output_data[output_index * cond_rank + j] = coord_j;
1599 flat_index %= dims_to_count[j];
1600 }
1601 output_index++;
1602 }
1603 }
1604 }
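// Worked example (hypothetical values): output_data receives the row-major
// coordinates of every true element, packed as cond_rank integers per hit.
//
//   condition (shape [2, 3]) = {0, 1, 0, 1, 0, 0}
//   dims_to_count            = {3, 1}
//   output_data              = {0, 1,   1, 0}   // coordinates (0, 1), (1, 0)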
1605
1606 // For ease of implementation, the indices are always a vector of size-4 vectors.
1607 template <typename T, typename TI>
1608 inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
1609 const T* values, T default_value,
1610 bool value_is_scalar,
1611 const RuntimeShape& unextended_output_shape,
1612 T* output_data) {
1613 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1614 const RuntimeShape output_shape =
1615 RuntimeShape::ExtendedShape(4, unextended_output_shape);
1616 const int value_count = indices.size();
1617
1618 // First, fill output_data with the default value.
1619 const int num_elements = output_shape.FlatSize();
1620 for (int i = 0; i < num_elements; ++i) {
1621 output_data[i] = default_value;
1622 }
1623
1624 // Special-case scalar values to avoid checking the boolean condition
1625 // within the loop on every iteration.
1626 if (value_is_scalar) {
1627 for (int i = 0; i < value_count; ++i) {
1628 const std::vector<TI>& index = indices[i];
1629 TFLITE_DCHECK_EQ(index.size(), 4);
1630 const T value = *values; // just use the first value.
1631 output_data[Offset(output_shape, index[0], index[1], index[2],
1632 index[3])] = value;
1633 }
1634 return;
1635 }
1636
1637 // Go through the values and indices to fill the sparse values.
1638 for (int i = 0; i < value_count; ++i) {
1639 const std::vector<TI>& index = indices[i];
1640 TFLITE_DCHECK_EQ(index.size(), 4);
1641 const T value = values[i];
1642 output_data[Offset(output_shape, index[0], index[1], index[2], index[3])] =
1643 value;
1644 }
1645 }
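// Worked example (hypothetical values): every index has been extended to four
// dimensions, matching the extended output shape.
//
//   indices       = {{0, 0, 0, 1}, {0, 0, 1, 0}}
//   values        = {5, 7}, default_value = 0, value_is_scalar = false
//   output shape  = [1, 1, 2, 2]
//   output_data   = {0, 5, 7, 0}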
1646
1647 template <typename T>
1648 inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
1649 const RuntimeShape& input2_shape, const T* input2_data,
1650 const RuntimeShape& output_shape, T* output_data) {
1651 const int flat_size =
1652 MatchingFlatSize(input1_shape, input2_shape, output_shape);
1653 for (int i = 0; i < flat_size; ++i) {
1654 output_data[i] = std::pow(input1_data[i], input2_data[i]);
1655 }
1656 }
1657
1658 template <typename T>
1659 inline void BroadcastPow4DSlow(const RuntimeShape& unextended_input1_shape,
1660 const T* input1_data,
1661 const RuntimeShape& unextended_input2_shape,
1662 const T* input2_data,
1663 const RuntimeShape& unextended_output_shape,
1664 T* output_data) {
1665 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
1666 TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
1667 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1668 const RuntimeShape output_shape =
1669 RuntimeShape::ExtendedShape(4, unextended_output_shape);
1670
1671 NdArrayDesc<4> desc1;
1672 NdArrayDesc<4> desc2;
1673 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
1674 unextended_input2_shape, &desc1, &desc2);
1675
1676 for (int b = 0; b < output_shape.Dims(0); ++b) {
1677 for (int y = 0; y < output_shape.Dims(1); ++y) {
1678 for (int x = 0; x < output_shape.Dims(2); ++x) {
1679 for (int c = 0; c < output_shape.Dims(3); ++c) {
1680 auto out_idx = Offset(output_shape, b, y, x, c);
1681 auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
1682 auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
1683 auto in1_val = input1_data[in1_idx];
1684 auto in2_val = input2_data[in2_idx];
1685 output_data[out_idx] = std::pow(in1_val, in2_val);
1686 }
1687 }
1688 }
1689 }
1690 }
1691
1692 template <typename Scalar>
1693 void Reverse(int axis, const RuntimeShape& input_shape,
1694 const Scalar* input_data, const RuntimeShape& output_shape,
1695 Scalar* output_data) {
1696 ruy::profiler::ScopeLabel label("Reverse");
1697
1698 int outer_size = 1;
1699 for (int i = 0; i < axis; ++i) {
1700 outer_size *= input_shape.Dims(i);
1701 }
1702
1703 int copy_size = 1;
1704 for (int i = axis + 1; i < input_shape.DimensionsCount(); ++i) {
1705 copy_size *= input_shape.Dims(i);
1706 }
1707
1708 const int dims_at_axis = input_shape.Dims(axis);
1709 for (int i = 0; i < outer_size; ++i) {
1710 for (int j = 0; j < dims_at_axis; ++j) {
1711 const int start_pos = (i * dims_at_axis + j) * copy_size;
1712 Scalar* output_ptr = output_data + start_pos;
1713 int loc = (i * dims_at_axis + dims_at_axis - j - 1) * copy_size;
1714 memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
1715 }
1716 }
1717 }
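// Worked example (hypothetical values): reversing axis 1 of a [2, 3] tensor
// reverses each row independently, since outer_size == 2 and copy_size == 1.
//
//   input  = {1, 2, 3, 4, 5, 6}
//   output = {3, 2, 1, 6, 5, 4}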
1718
1719 template <typename Scalar, typename TS>
1720 void ReverseSequence(const TS* seq_lengths, const int seq_dim,
1721 const int batch_dim, const RuntimeShape& input_shape,
1722 const Scalar* input_data, const RuntimeShape& output_shape,
1723 Scalar* output_data) {
1724 ruy::profiler::ScopeLabel label("ReverseSequence");
1725
1726 int outer_size = 1;
1727 int outer_dim = std::min(batch_dim, seq_dim);
1728 int medium_dim = std::max(batch_dim, seq_dim);
1729 for (int i = 0; i < outer_dim; ++i) {
1730 outer_size *= input_shape.Dims(i);
1731 }
1732
1733 int medium_size = 1;
1734 for (int i = outer_dim + 1; i < medium_dim; ++i) {
1735 medium_size *= input_shape.Dims(i);
1736 }
1737
1738 int copy_size = 1;
1739 for (int i = medium_dim + 1; i < input_shape.DimensionsCount(); ++i) {
1740 copy_size *= input_shape.Dims(i);
1741 }
1742
1743 const int dims_at_outer_dim = input_shape.Dims(outer_dim);
1744 const int dims_at_medium_dim = input_shape.Dims(medium_dim);
1745
1746 Scalar* output_ptr;
1747 if (batch_dim > seq_dim) {
1748 for (int i = 0; i < outer_size; ++i) {
1749 for (int j = 0; j < dims_at_outer_dim; ++j) {
1750 const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size;
1751 for (int p = 0; p < medium_size; ++p) {
1752 for (int q = 0; q < dims_at_medium_dim; ++q) {
1753 const int in_pos =
1754 ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size;
1755 const Scalar* in_ptr = input_data + in_pos;
1756 int sl = seq_lengths[q] - 1;
1757 if (j > sl) {
1758 output_ptr = output_data + in_pos;
1759 } else {
1760 const int out_pos_base =
1761 (i * dims_at_outer_dim + sl - j) * medium_size;
1762 const int out_pos =
1763 ((out_pos_base + p) * dims_at_medium_dim + q) * copy_size;
1764 output_ptr = output_data + out_pos;
1765 }
1766 memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar));
1767 }
1768 }
1769 }
1770 }
1771 } else if (batch_dim < seq_dim) {
1772 for (int i = 0; i < outer_size; ++i) {
1773 for (int j = 0; j < dims_at_outer_dim; ++j) {
1774 const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size;
1775 int sl = seq_lengths[j] - 1;
1776 const int out_pos_base = (i * dims_at_outer_dim + j) * medium_size;
1777 for (int p = 0; p < medium_size; ++p) {
1778 for (int q = 0; q < dims_at_medium_dim; ++q) {
1779 const int in_pos =
1780 ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size;
1781 const Scalar* in_ptr = input_data + in_pos;
1782 if (q > sl) {
1783 output_ptr = output_data + in_pos;
1784 } else {
1785 const int out_pos =
1786 ((out_pos_base + p) * dims_at_medium_dim + sl - q) *
1787 copy_size;
1788 output_ptr = output_data + out_pos;
1789 }
1790 memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar));
1791 }
1792 }
1793 }
1794 }
1795 }
1796 }
1797
1798 template <typename T>
1799 inline void SegmentSum(const RuntimeShape& input_shape, const T* input_data,
1800 const RuntimeShape& segment_ids_shape,
1801 const int32_t* segment_ids_data,
1802 const RuntimeShape& output_shape, T* output_data) {
1803 const int segment_flat_size =
1804 MatchingFlatSizeSkipDim(input_shape, 0, output_shape);
1805
1806 memset(output_data, 0, sizeof(T) * output_shape.FlatSize());
1807
1808 for (int i = 0; i < input_shape.Dims(0); i++) {
1809 int output_index = segment_ids_data[i];
1810 for (int j = 0; j < segment_flat_size; ++j) {
1811 output_data[output_index * segment_flat_size + j] +=
1812 input_data[i * segment_flat_size + j];
1813 }
1814 }
1815 }
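// Worked example (hypothetical values): input rows that share a segment id
// are accumulated into the corresponding output row; output_shape.Dims(0)
// must cover the largest segment id for the indexing above to stay in bounds.
//
//   input       (shape [4, 2]) = {1, 2, 3, 4, 5, 6, 7, 8}
//   segment_ids (shape [4])    = {0, 0, 1, 1}
//   output      (shape [2, 2]) = {1 + 3, 2 + 4, 5 + 7, 6 + 8} = {4, 6, 12, 14}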
1816
1817 } // namespace reference_ops
1818 } // namespace tflite
1819
1820 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
1821