1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stddef.h>
16 
17 #include <algorithm>
18 #include <cmath>
19 #include <cstdint>
20 #include <functional>
21 #include <limits>
22 
23 #include "tensorflow/lite/c/builtin_op_data.h"
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow/lite/kernels/cpu_backend_context.h"
26 #include "tensorflow/lite/kernels/internal/common.h"
27 #include "tensorflow/lite/kernels/internal/compatibility.h"
28 #include "tensorflow/lite/kernels/internal/cppmath.h"
29 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
30 #include "tensorflow/lite/kernels/internal/quantization_util.h"
31 #include "tensorflow/lite/kernels/internal/reference/binary_function.h"
32 #include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h"
33 #include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
34 #include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
35 #include "tensorflow/lite/kernels/internal/reference/logistic.h"
36 #include "tensorflow/lite/kernels/internal/reference/prelu.h"
37 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
38 #include "tensorflow/lite/kernels/internal/reference/softmax.h"
39 #include "tensorflow/lite/kernels/internal/reference/tanh.h"
40 #include "tensorflow/lite/kernels/internal/tensor.h"
41 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
42 #include "tensorflow/lite/kernels/internal/types.h"
43 #include "tensorflow/lite/kernels/kernel_util.h"
44 
45 #if __aarch64__ && __clang__
46 #include <arm_neon.h>
47 #endif
48 
49 namespace tflite {
50 namespace ops {
51 namespace builtin {
52 namespace activations {
53 
54 // TODO(b/142762739): We should figure out a multi-threading plan for most of
55 // the activation ops below.
56 
57 enum KernelType {
58   kReference,
59   kGenericOptimized,
60   kFixedPointOptimized,
61 };
62 
63 struct OpData {
64   int32_t input_multiplier = 0;
65   int input_left_shift = 0;
66   int32_t input_range_radius = 0;
67   int diff_min = 0;
68   uint8_t table[256] = {0};
69 };
70 
71 struct SoftmaxOpData {
72   struct SoftmaxParams params = {};
73   float table[256];
74 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
75   uint8_t uint8_table1[256];
76   uint8_t uint8_table2[256];
77 #endif
78   static constexpr int kInt16LUTArraySize = 513;
79   int16_t exp_lut[kInt16LUTArraySize];  // int16 LUT for exp(x), where x is
80                                         // uniformly distributed in [-10.0, 0.0]
81   int16_t one_over_one_plus_x_lut[kInt16LUTArraySize];  // int16 LUT for 1 /
82                                                         // (1 + x), where x is
83                                                         // uniformly distributed
84                                                         // in [0.0, 1.0]
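  // Note (assumption, not stated in the original source): the 513-entry size
  // presumably corresponds to 512 uniform steps plus the final endpoint, so a
  // lookup can interpolate between adjacent LUT entries.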
85 };
86 
87 struct LogSoftmaxOpData : public OpData {
88   int32_t reverse_scaling_divisor = 0;
89   int32_t reverse_scaling_right_shift = 0;
90   struct SoftmaxParams params = {};
91   float f_table[256];
92 };
93 
94 struct LeakyReluOpData : public OpData {
95   int32_t output_multiplier_alpha = 0;
96   int32_t output_shift_alpha = 0;
97   int32_t output_multiplier_identity = 0;
98   int32_t output_shift_identity = 0;
99 };
100 
101 struct PreluOpData : public OpData {
102   int32_t output_multiplier_1 = 0;
103   int32_t output_shift_1 = 0;
104   int32_t output_multiplier_2 = 0;
105   int32_t output_shift_2 = 0;
106   bool requires_broadcast;
107 };
108 
109 struct HardSwishData {
110   HardSwishParams params;
111 };
112 
113 struct ReluOpData : public OpData {
114   int32_t output_multiplier = 0;
115   int output_shift = 0;
116 };
117 
118 namespace {
119 TfLiteStatus CheckOutputQuantParams(TfLiteContext* context,
120                                     const TfLiteTensor* input,
121                                     const TfLiteTensor* output) {
122   TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
123   if (input->type == kTfLiteUInt8) {
124     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
125   } else {
126     TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
127   }
128   return kTfLiteOk;
129 }
130 
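// Illustrative note on the lookup-table construction below (hypothetical
// scales, not taken from the original source): PopulateLookupTable precomputes
// a 256-entry table mapping every possible quantized input value directly to
// the quantized output of `transform`. For example, with an int8 input scale
// of 1/64 (zero_point 0), an output scale of 1/256 (zero_point -128), and
// transform = tanh, the entry at index 64 dequantizes to 1.0, applies
// tanh(1.0) ~= 0.76, and stores round(0.76 * 256) + (-128) = 67.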
131 template <typename T>
132 void PopulateLookupTable(struct OpData* data, const TfLiteTensor* input,
133                          TfLiteTensor* output,
134                          const std::function<float(float)>& transform) {
135   static_assert(sizeof(T) == 1, "Lookup table valid only for 8bit");
136   const float inverse_scale = 1 / output->params.scale;
137   int32_t maxval = std::numeric_limits<T>::max();
138   int32_t minval = std::numeric_limits<T>::min();
139   for (int32_t val = minval; val <= maxval; ++val) {
140     const float dequantized =
141         input->params.scale * (val - input->params.zero_point);
142     const float transformed = transform(dequantized);
143     const float rescaled = std::round(transformed * inverse_scale);
144     const int32_t quantized =
145         static_cast<int32_t>(rescaled + output->params.zero_point);
146     data->table[static_cast<uint8_t>(static_cast<T>(val))] =
147         static_cast<uint8_t>(
148             static_cast<T>(std::max(std::min(maxval, quantized), minval)));
149   }
150 }
151 
152 // TODO(b/143696793): move this to optimized_ops.
153 void EvalUsingLookupTable(struct OpData* data, const TfLiteTensor* input,
154                           TfLiteTensor* output) {
155   const int size =
156       MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
157   uint8_t* output_data = GetTensorData<uint8_t>(output);
158   const uint8_t* input_data = GetTensorData<uint8_t>(input);
159   int i = 0;
160 #if __aarch64__ && __clang__
161   // This code uses ARM64-only instructions.
162   // TODO(b/143709993): Port to ARMv7
163 
164   // Load the tables into registers. (4*4 128-bit registers)
165   uint8x16x4_t table[4];
166   table[0] = vld1q_u8_x4(data->table + 16 * 4 * 0);
167   table[1] = vld1q_u8_x4(data->table + 16 * 4 * 1);
168   table[2] = vld1q_u8_x4(data->table + 16 * 4 * 2);
169   table[3] = vld1q_u8_x4(data->table + 16 * 4 * 3);
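  // With the four loads above, the entire 256-byte table sits in 16 NEON
  // registers, so each loop iteration below can translate 16 inputs at once
  // (presumably via TBL-style byte-table lookups inside
  // aarch64_lookup_vector).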
170 
171   // Vectorized loop; process uint8x16_t (16 elements) at a time.
172   constexpr int vectorized_16_loop_step = 16;
173   const int vectorized_16_loop_end =
174       size / vectorized_16_loop_step * vectorized_16_loop_step;
175   for (; i < vectorized_16_loop_end; i += vectorized_16_loop_step) {
176     uint8x16_t input = vld1q_u8(input_data + i);
177     uint8x16_t output = optimized_ops::aarch64_lookup_vector(table, input);
178     vst1q_u8(output_data + i, output);
179   }
180   // Postamble and non-ARM64 code: simple for loop.
181 #endif
182   for (; i < size; ++i) {
183     output_data[i] = data->table[input_data[i]];
184   }
185 }
186 
187 template <typename T>
188 void QuantizedReluX(float act_min, float act_max, const TfLiteTensor* input,
189                     TfLiteTensor* output, const ReluOpData* data) {
190   ReluParams params;
191   params.quantized_activation_min =
192       std::max(static_cast<int32_t>(std::numeric_limits<T>::min()),
193                output->params.zero_point +
194                    static_cast<int32>(roundf(act_min / output->params.scale)));
195   params.quantized_activation_max =
196       act_max == std::numeric_limits<float>::infinity()
197           ? static_cast<int32_t>(std::numeric_limits<T>::max())
198           : std::min(
199                 static_cast<int32_t>(std::numeric_limits<T>::max()),
200                 output->params.zero_point +
201                     static_cast<int32>(roundf(act_max / output->params.scale)));
202   params.input_offset = input->params.zero_point;
203   params.output_offset = output->params.zero_point;
204   params.output_multiplier = data->output_multiplier;
205   params.output_shift = data->output_shift;
206   optimized_ops::ReluX(params, GetTensorShape(input), GetTensorData<T>(input),
207                        GetTensorShape(output), GetTensorData<T>(output));
208 }
209 
210 }  // namespace
211 
212 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
213   // This is a builtin op, so we don't use the contents in 'buffer', if any.
214   // Instead, we allocate a new object to carry information from Prepare() to
215   // Eval().
216   return new OpData;
217 }
218 
219 void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
220   return new SoftmaxOpData;
221 }
222 
223 void SoftmaxFree(TfLiteContext* context, void* buffer) {
224   delete reinterpret_cast<SoftmaxOpData*>(buffer);
225 }
226 
227 void* LogSoftmaxInit(TfLiteContext* context, const char* buffer,
228                      size_t length) {
229   return new LogSoftmaxOpData;
230 }
231 
232 void* PreluInit(TfLiteContext* context, const char* buffer, size_t length) {
233   return new PreluOpData;
234 }
235 
236 void Free(TfLiteContext* context, void* buffer) {
237   delete reinterpret_cast<OpData*>(buffer);
238 }
239 
240 void LogSoftmaxFree(TfLiteContext* context, void* buffer) {
241   delete reinterpret_cast<LogSoftmaxOpData*>(buffer);
242 }
243 
244 void PreluFree(TfLiteContext* context, void* buffer) {
245   delete reinterpret_cast<PreluOpData*>(buffer);
246 }
247 
248 void* HardSwishInit(TfLiteContext* context, const char* buffer, size_t length) {
249   return new HardSwishData;
250 }
251 
252 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
253   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
254   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
255   const TfLiteTensor* input;
256   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
257   TfLiteTensor* output;
258   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
259   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
260 
261   return context->ResizeTensor(context, output,
262                                TfLiteIntArrayCopy(input->dims));
263 }
264 
265 void* ReluInit(TfLiteContext* context, const char* buffer, size_t length) {
266   return new ReluOpData;
267 }
268 
269 void ReluFree(TfLiteContext* context, void* buffer) {
270   delete reinterpret_cast<ReluOpData*>(buffer);
271 }
272 
273 TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
274   ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
275   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
276   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
277   const TfLiteTensor* input;
278   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
279   TfLiteTensor* output;
280   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
281   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
282 
283   if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8 ||
284       input->type == kTfLiteInt16) {
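    // The quantized Relu path only re-quantizes from the input scale to the
    // output scale, so a single multiplier input_scale / output_scale is
    // prepared here.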
285     double real_multiplier = input->params.scale / output->params.scale;
286     QuantizeMultiplier(real_multiplier, &data->output_multiplier,
287                        &data->output_shift);
288   }
289 
290   if (input->type == kTfLiteInt16) {
291     TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
292     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
293   }
294 
295   return context->ResizeTensor(context, output,
296                                TfLiteIntArrayCopy(input->dims));
297 }
298 
299 void* LeakyReluInit(TfLiteContext* context, const char* buffer, size_t length) {
300   return new LeakyReluOpData;
301 }
302 
303 void LeakyReluFree(TfLiteContext* context, void* buffer) {
304   delete reinterpret_cast<LeakyReluOpData*>(buffer);
305 }
306 
307 void HardSwishFree(TfLiteContext* context, void* buffer) {
308   delete static_cast<HardSwishData*>(buffer);
309 }
310 
311 TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
312   TF_LITE_ENSURE_STATUS(GenericPrepare(context, node));
313   TfLiteTensor* output;
314   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
315 
316   if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
317     HardSwishData* data = static_cast<HardSwishData*>(node->user_data);
318     HardSwishParams* params = &data->params;
319     const TfLiteTensor* input;
320     TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
321     params->input_zero_point = input->params.zero_point;
322     params->output_zero_point = output->params.zero_point;
323     const float input_scale = input->params.scale;
324     const float hires_input_scale = (1.0f / 128.0f) * input_scale;
325     const float reluish_scale = 3.0f / 32768.0f;
326     const float output_scale = output->params.scale;
327 
328     const float output_multiplier = hires_input_scale / output_scale;
329 
330     int32_t output_multiplier_fixedpoint_int32;
331     QuantizeMultiplier(output_multiplier, &output_multiplier_fixedpoint_int32,
332                        &params->output_multiplier_exponent);
333     DownScaleInt32ToInt16Multiplier(
334         output_multiplier_fixedpoint_int32,
335         &params->output_multiplier_fixedpoint_int16);
336     TF_LITE_ENSURE(context, params->output_multiplier_exponent <= 0);
337 
338     const float reluish_multiplier = hires_input_scale / reluish_scale;
339     int32_t reluish_multiplier_fixedpoint_int32;
340     QuantizeMultiplier(reluish_multiplier, &reluish_multiplier_fixedpoint_int32,
341                        &params->reluish_multiplier_exponent);
342     DownScaleInt32ToInt16Multiplier(
343         reluish_multiplier_fixedpoint_int32,
344         &params->reluish_multiplier_fixedpoint_int16);
345   }
346   return kTfLiteOk;
347 }
348 
349 TfLiteStatus LeakyReluPrepare(TfLiteContext* context, TfLiteNode* node) {
350   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
351   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
352   const TfLiteTensor* input;
353   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
354   TfLiteTensor* output;
355   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
356   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
357 
358   LeakyReluOpData* data = reinterpret_cast<LeakyReluOpData*>(node->user_data);
359 
360   if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
361       output->type == kTfLiteInt16) {
362     const auto* params =
363         reinterpret_cast<TfLiteLeakyReluParams*>(node->builtin_data);
364 
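    // LeakyRelu(x) is x for x >= 0 and alpha * x otherwise, so two
    // requantization multipliers are prepared: input_scale * alpha /
    // output_scale for the negative branch and input_scale / output_scale for
    // the identity branch.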
365     double alpha_multiplier =
366         input->params.scale * params->alpha / output->params.scale;
367     QuantizeMultiplier(alpha_multiplier, &data->output_multiplier_alpha,
368                        &data->output_shift_alpha);
369     double identity_multiplier = input->params.scale / output->params.scale;
370     QuantizeMultiplier(identity_multiplier, &data->output_multiplier_identity,
371                        &data->output_shift_identity);
372   }
373 
374   if (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) {
375     TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
376     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
377   }
378 
379   return context->ResizeTensor(context, output,
380                                TfLiteIntArrayCopy(input->dims));
381 }
382 
383 template <KernelType kernel_type>
384 TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {
385   OpData* data = reinterpret_cast<OpData*>(node->user_data);
386 
387   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
388   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
389   const TfLiteTensor* input;
390   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
391   TfLiteTensor* output;
392   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
393   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
394 
395   if (kernel_type == kFixedPointOptimized) {
396     if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
397       static constexpr int kInputIntegerBits = 4;
398 
399       const double input_real_multiplier =
400           input->params.scale *
401           static_cast<double>(1 << (15 - kInputIntegerBits));
402 
403       const double q =
404           std::frexp(input_real_multiplier, &data->input_left_shift);
405       auto q_fixed = static_cast<int32_t>(TfLiteRound(q * (1ll << 15)));
406       data->input_multiplier = static_cast<int16_t>(q_fixed);
407 
408       int16_t input_range_radius =
409           CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 15);
410       data->input_range_radius = input_range_radius;
411     }
412   }
413 
414   if (kernel_type == kGenericOptimized || kernel_type == kReference) {
415     if (input->type == kTfLiteUInt8) {
416       PopulateLookupTable<uint8_t>(
417           data, input, output, [](float value) { return std::tanh(value); });
418     } else if (input->type == kTfLiteInt8) {
419       PopulateLookupTable<int8_t>(data, input, output,
420                                   [](float value) { return std::tanh(value); });
421     }
422   }
423 
424   if (input->type == kTfLiteInt16) {
425     static constexpr int kInputIntegerBits = 3;
426     static constexpr int kOutputFractionalBits = 15;
427 
428     // These operators are implemented in fixed-point arithmetic,
429     // which intrinsically wants symmetric ranges (zero_point==0)
430     // and power-of-two scales (power-of-two is abbreviated below as POT).
431     // While more general support would be possible by means of rescaling,
432     // that would add some overhead and some loss of accuracy and wouldn't
433     // be used at the moment as current quantized LSTM applications are
434     // happy with symmetric, power-of-two-scales quantization. So we just
435     // implement that narrow case only for now.
436 
437     TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
438     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
439 
440     int input_scale_log2_rounded;
441     bool param_scale_pot =
442         CheckedLog2(input->params.scale, &input_scale_log2_rounded);
443 
444     data->input_left_shift =
445         (15 - kInputIntegerBits) + input_scale_log2_rounded;
446     param_scale_pot &=
447         (data->input_left_shift == 0 || data->input_left_shift == 1);
448 
449     if (!param_scale_pot) {
450       // Calculate a multiplier to change the input scale to 1/(3*4096),
451       // as required by the table lookup.
452       // The factor 3.0 in the multiplier comes from the interval being
453       // [-10.7, 10.7] instead of [-8, 8], so in this scaling
454       // +/-2^17 represents +/-10.7.
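      // Illustrative example (hypothetical input scale of 1/1024): the
      // multiplier starts at (1/1024) * 4096 * 3 = 12.0 and is doubled 11
      // times by the loop below, ending with input_multiplier = 24576 and
      // input_left_shift = 11.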
455 
456       double multiplier = input->params.scale * 4096.0 * 3.0;
457       data->input_left_shift = 0;
458 
459       while (multiplier <= 32767.0 / 2.0 && data->input_left_shift <= 30) {
460         data->input_left_shift++;
461         multiplier = multiplier * 2.0;
462       }
463 
464       data->input_multiplier = static_cast<int32_t>(multiplier);
465     }
466 
467     int output_scale_log2_rounded;
468     TF_LITE_ENSURE(
469         context, CheckedLog2(output->params.scale, &output_scale_log2_rounded));
470     TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded,
471                       -kOutputFractionalBits);
472   }
473 
474   return context->ResizeTensor(context, output,
475                                TfLiteIntArrayCopy(input->dims));
476 }
477 
478 template <KernelType kernel_type>
479 TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {
480   OpData* data = reinterpret_cast<OpData*>(node->user_data);
481 
482   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
483   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
484   const TfLiteTensor* input;
485   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
486   TfLiteTensor* output;
487   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
488   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
489 
490   if (kernel_type == kFixedPointOptimized) {
491     if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
492       if (input->type == kTfLiteUInt8) {
493         TF_LITE_ENSURE_EQ(context, output->params.zero_point,
494                           std::numeric_limits<uint8_t>::min());
495       }
496       if (input->type == kTfLiteInt8) {
497         TF_LITE_ENSURE_EQ(context, output->params.zero_point,
498                           std::numeric_limits<int8_t>::min());
499       }
500       TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
501 
502       static constexpr int kInputIntegerBits = 4;
503 
504       const double input_real_multiplier =
505           input->params.scale *
506           static_cast<double>(1 << (15 - kInputIntegerBits));
507 
508       const double q =
509           std::frexp(input_real_multiplier, &data->input_left_shift);
510       auto q_fixed = static_cast<int32_t>(TfLiteRound(q * (1ll << 15)));
511       data->input_multiplier = static_cast<int16_t>(q_fixed);
512 
513       int16_t input_range_radius =
514           CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 15);
515       data->input_range_radius = input_range_radius;
516     }
517   }
518 
519   if (kernel_type == kGenericOptimized || kernel_type == kReference) {
520     if (input->type == kTfLiteUInt8) {
521       TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
522       PopulateLookupTable<uint8_t>(data, input, output, [](float value) {
523         return 1.0f / (1.0f + std::exp(-value));
524       });
525     } else if (input->type == kTfLiteInt8) {
526       TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
527       PopulateLookupTable<int8_t>(data, input, output, [](float value) {
528         return 1.0f / (1.0f + std::exp(-value));
529       });
530     } else if (input->type == kTfLiteInt16) {
531       TF_LITE_ENSURE(context, output->params.scale == 1. / 32768);
532       TF_LITE_ENSURE(context, output->params.zero_point == 0);
533     }
534   }
535 
536   if (input->type == kTfLiteInt16) {
537     static constexpr int kInputIntegerBits = 3;
538     static constexpr int kOutputFractionalBits = 15;
539 
540     // See comments in TanhPrepare about requiring zero_point==0
541     // and a power-of-two ("POT") scale.
542 
543     TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
544     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
545 
546     int input_scale_log2_rounded;
547     bool param_scale_pot =
548         CheckedLog2(input->params.scale, &input_scale_log2_rounded);
549 
550     data->input_left_shift =
551         (15 - kInputIntegerBits) + input_scale_log2_rounded;
552     param_scale_pot &= (data->input_left_shift == 0);
553 
554     if (!param_scale_pot) {
555       // Calculate multiplier to change input scale to 1/(3*4096)
556       // as required by the table lookup.
557       // In this scaling +/-2^17 represents +/-10.7
558       double multiplier = input->params.scale * 4096.0 * 3.0;
559 
560       data->input_left_shift = 0;
561 
562       while (multiplier <= 32767.0 / 2.0 && data->input_left_shift <= 30) {
563         data->input_left_shift++;
564         multiplier = multiplier * 2.0;
565       }
566 
567       data->input_multiplier = static_cast<int32_t>(multiplier);
568     }
569 
570     int output_scale_log2_rounded;
571     TF_LITE_ENSURE(
572         context, CheckedLog2(output->params.scale, &output_scale_log2_rounded));
573     TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded,
574                       -kOutputFractionalBits);
575   }
576 
577   return context->ResizeTensor(context, output,
578                                TfLiteIntArrayCopy(input->dims));
579 }
580 
581 TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
582   auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
583   SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);
584 
585   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
586   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
587   const TfLiteTensor* input;
588   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
589   TfLiteTensor* output;
590   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
591   if (output->type == kTfLiteInt16) {
592     TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
593                                 input->type == kTfLiteUInt8 ||
594                                 input->type == kTfLiteInt16);
595   } else {
596     TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
597   }
598 
599   TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
600 
601   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
602     switch (output->type) {
603       case kTfLiteUInt8:
604       case kTfLiteInt8:
605 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
606         // Only applies when both input & output are uint8/int8 and the
607         // build uses clang on aarch64.
608         // TODO(b/143709993): Port to ARMv7 and other platforms.
609         data->params.uint8_table1 = data->uint8_table1;
610         data->params.uint8_table2 = data->uint8_table2;
611         optimized_ops::PopulateSoftmaxUInt8LookupTable(
612             &data->params, input->params.scale, params->beta);
613         break;
614 #endif
615       case kTfLiteInt16:
616       default:
617         data->params.table = data->table;
618         optimized_ops::PopulateSoftmaxLookupTable(
619             &data->params, input->params.scale, params->beta);
620     }
621 
622     data->params.zero_point = output->params.zero_point;
623     data->params.scale = output->params.scale;
624   }
625 
626   if (input->type == kTfLiteInt16) {
627     TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
628     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
629 
630     data->params.exp_lut = data->exp_lut;
631     // The exp LUT is only used on negative values;
632     // we consider exp(-10.0) insignificant to the accumulation.
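    // (exp(-10.0) is roughly 4.5e-5, so truncating the LUT domain at -10.0
    // contributes a negligible error to the accumulated sum.)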
633     gen_lut([](double value) { return std::exp(value); }, -10.0, 0.0,
634             data->params.exp_lut, data->kInt16LUTArraySize);
635     data->params.one_over_one_plus_x_lut = data->one_over_one_plus_x_lut;
636     gen_lut([](double value) { return 1.0 / (1.0 + value); }, 0.0, 1.0,
637             data->params.one_over_one_plus_x_lut, data->kInt16LUTArraySize);
638     data->params.zero_point = output->params.zero_point;
639     data->params.scale = output->params.scale;
640 
641     double input_scale_beta_rescale =
642         input->params.scale * params->beta /
643         (10.0 / 65535.0);  // scale the input_diff such that [-65535, 0]
644                            // correspond to [-10.0, 0.0]
645     QuantizeMultiplier(input_scale_beta_rescale, &data->params.input_multiplier,
646                        &data->params.input_left_shift);
647   }
648 
649   return context->ResizeTensor(context, output,
650                                TfLiteIntArrayCopy(input->dims));
651 }
652 
653 TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
654   LogSoftmaxOpData* data = reinterpret_cast<LogSoftmaxOpData*>(node->user_data);
655 
656   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
657   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
658   const TfLiteTensor* input;
659   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
660   TfLiteTensor* output;
661   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
662   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
663 
664   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
665     TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256);
666     static const double kBeta = 1.0;
667     if (input->type == kTfLiteUInt8) {
668       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255);
669     }
670     if (input->type == kTfLiteInt8) {
671       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 127);
672     }
673     data->params.table = data->f_table;
674     optimized_ops::PopulateSoftmaxLookupTable(&data->params,
675                                               input->params.scale, kBeta);
676     data->params.zero_point = output->params.zero_point;
677     data->params.scale = output->params.scale;
678   }
679 
680   return context->ResizeTensor(context, output,
681                                TfLiteIntArrayCopy(input->dims));
682 }
683 
684 TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
685   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
686   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
687   const TfLiteTensor* input;
688   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
689   TfLiteTensor* output;
690   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
691   const TfLiteTensor* alpha;
692   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &alpha));
693   PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);
694 
695   TF_LITE_ENSURE_TYPES_EQ(context, input->type, alpha->type);
696 
697   output->type = input->type;
698 
699   if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
700     // prelu(x) = x if x >= 0 else x * alpha.
701     // So if we translate that for quantized computation:
702     //
703     // input_float = (input_q - input_zp) * input_scale
704     // output_float = (output_q - output_zp) * output_scale
705     // alpha_float = (alpha_q - alpha_zp) * alpha_scale
706     //
707     // When input_q - input_zp >= 0:
708     // output_q = (input_q - input_zp) * input_scale / output_scale + output_zp
709     // else:
710     // output_q = (input_q - input_zp) * (alpha_q - alpha_zp) * input_scale
711     //            * alpha_scale / output_scale + output_zp
712     //
713     // So for input_q - input_zp >= 0:
714     // output real multiplier 1 is input_scale / output_scale;
715     // for input_q - input_zp < 0:
716     // output real multiplier 2 is input_scale * alpha_scale / output_scale.
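    // Illustrative example (hypothetical scales): with input_scale = 0.02,
    // alpha_scale = 0.01 and output_scale = 0.02, real_multiplier_1 = 1.0 and
    // real_multiplier_2 = 0.01; QuantizeMultiplier then encodes each as a
    // fixed-point int32 multiplier plus a shift.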
717     double real_multiplier_1 = input->params.scale / output->params.scale;
718     double real_multiplier_2 =
719         input->params.scale * alpha->params.scale / output->params.scale;
720     QuantizeMultiplier(real_multiplier_1, &data->output_multiplier_1,
721                        &data->output_shift_1);
722     QuantizeMultiplier(real_multiplier_2, &data->output_multiplier_2,
723                        &data->output_shift_2);
724   }
725 
726   data->requires_broadcast = !HaveSameShapes(input, alpha);
727   // PRelu (parametric Relu) shares the same alpha value on the "shared axis".
728   // This means it's always required to "broadcast" alpha values in PRelu.
729   TfLiteIntArray* output_size = nullptr;
730   TF_LITE_ENSURE_OK(
731       context, CalculateShapeForBroadcast(context, input, alpha, &output_size));
732 
733   TF_LITE_ENSURE_OK(context,
734                     context->ResizeTensor(context, output, output_size));
735   // After broadcasting, the output shape should always be the same as the
736   // input shape.
737   TF_LITE_ENSURE(context, HaveSameShapes(input, output));
738 
739   return kTfLiteOk;
740 }
741 
742 TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
743   const TfLiteTensor* input;
744   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
745   TfLiteTensor* output;
746   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
747   const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
748   switch (input->type) {
749     case kTfLiteFloat32: {
750       optimized_ops::Relu(GetTensorShape(input), GetTensorData<float>(input),
751                           GetTensorShape(output), GetTensorData<float>(output));
752     } break;
753     // TODO(renjieliu): We may revisit the quantization calculation logic;
754     // the unbounded upper limit is actually hard to quantize.
755     case kTfLiteUInt8: {
756       QuantizedReluX<uint8_t>(0.0f, std::numeric_limits<float>::infinity(),
757                               input, output, data);
758     } break;
759     case kTfLiteInt8: {
760       QuantizedReluX<int8_t>(0.0f, std::numeric_limits<float>::infinity(),
761                              input, output, data);
762     } break;
763     case kTfLiteInt16: {
764       QuantizedReluX<int16_t>(0.0f, std::numeric_limits<float>::infinity(),
765                               input, output, data);
766     } break;
767     default:
768       TF_LITE_KERNEL_LOG(context,
769                          "Only float32, uint8, int8 and int16 are supported "
770                          "currently, got %s.",
771                          TfLiteTypeGetName(input->type));
772       return kTfLiteError;
773   }
774   return kTfLiteOk;
775 }
776 
777 TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
778   const TfLiteTensor* input;
779   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
780   TfLiteTensor* output;
781   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
782   const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
783   switch (input->type) {
784     case kTfLiteFloat32: {
785       optimized_ops::Relu1(GetTensorShape(input), GetTensorData<float>(input),
786                            GetTensorShape(output),
787                            GetTensorData<float>(output));
788       return kTfLiteOk;
789     } break;
790     case kTfLiteUInt8: {
791       QuantizedReluX<uint8_t>(-1.0f, 1.0f, input, output, data);
792       return kTfLiteOk;
793     } break;
794     case kTfLiteInt8: {
795       QuantizedReluX<int8_t>(-1, 1, input, output, data);
796       return kTfLiteOk;
797     } break;
798     default:
799       TF_LITE_KERNEL_LOG(context,
800                          "Only float32, uint8, int8 supported "
801                          "currently, got %s.",
802                          TfLiteTypeGetName(input->type));
803       return kTfLiteError;
804   }
805 }
806 
807 template <KernelType kernel_type>
808 TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
809   HardSwishData* data = static_cast<HardSwishData*>(node->user_data);
810 
811   const TfLiteTensor* input;
812   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
813   TfLiteTensor* output;
814   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
815   switch (input->type) {
816     case kTfLiteFloat32: {
817       if (kernel_type == kReference) {
818         reference_ops::HardSwish(
819             GetTensorShape(input), GetTensorData<float>(input),
820             GetTensorShape(output), GetTensorData<float>(output));
821       } else {
822         optimized_ops::HardSwish(
823             GetTensorShape(input), GetTensorData<float>(input),
824             GetTensorShape(output), GetTensorData<float>(output));
825       }
826       return kTfLiteOk;
827     } break;
828     case kTfLiteUInt8: {
829       HardSwishParams& params = data->params;
830       if (kernel_type == kReference) {
831         reference_ops::HardSwish(
832             params, GetTensorShape(input), GetTensorData<uint8_t>(input),
833             GetTensorShape(output), GetTensorData<uint8_t>(output));
834       } else {
835         optimized_ops::HardSwish(
836             params, GetTensorShape(input), GetTensorData<uint8_t>(input),
837             GetTensorShape(output), GetTensorData<uint8_t>(output));
838       }
839       return kTfLiteOk;
840     } break;
841     case kTfLiteInt8: {
842       HardSwishParams& params = data->params;
843       if (kernel_type == kReference) {
844         reference_ops::HardSwish(
845             params, GetTensorShape(input), GetTensorData<int8_t>(input),
846             GetTensorShape(output), GetTensorData<int8_t>(output));
847       } else {
848         optimized_ops::HardSwish(
849             params, GetTensorShape(input), GetTensorData<int8_t>(input),
850             GetTensorShape(output), GetTensorData<int8_t>(output));
851       }
852       return kTfLiteOk;
853     } break;
854     default:
855       TF_LITE_KERNEL_LOG(
856           context,
857           "Only float32, uint8 and int8 are supported currently, got %s.",
858           TfLiteTypeGetName(input->type));
859       return kTfLiteError;
860   }
861 }
862 
863 TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
864   const TfLiteTensor* input;
865   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
866   TfLiteTensor* output;
867   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
868   ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
869   switch (input->type) {
870     case kTfLiteFloat32: {
871       size_t elements = input->bytes / sizeof(float);
872       const float* in = GetTensorData<float>(input);
873       const float* in_end = in + elements;
874       float* out = GetTensorData<float>(output);
875       for (; in < in_end; in++, out++) *out = std::min(std::max(0.f, *in), 6.f);
876       return kTfLiteOk;
877     } break;
878     case kTfLiteUInt8:
879       QuantizedReluX<uint8_t>(0.0f, 6.0f, input, output, data);
880       return kTfLiteOk;
881     case kTfLiteInt8: {
882       QuantizedReluX<int8_t>(0.0f, 6.0f, input, output, data);
883       return kTfLiteOk;
884     } break;
885     case kTfLiteInt16: {
886       QuantizedReluX<int16_t>(0.0f, 6.0f, input, output, data);
887       return kTfLiteOk;
888     } break;
889     default:
890       TF_LITE_KERNEL_LOG(context,
891                          "Only float32, uint8, int8 and int16 are supported "
892                          "currently, got %s.",
893                          TfLiteTypeGetName(input->type));
894       return kTfLiteError;
895   }
896 }
897 
898 template <KernelType kernel_type>
899 TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
900   OpData* data = reinterpret_cast<OpData*>(node->user_data);
901   const TfLiteTensor* input;
902   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
903   TfLiteTensor* output;
904   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
905   switch (input->type) {
906     case kTfLiteFloat32: {
907       if (kernel_type == kReference) {
908         reference_ops::Tanh(GetTensorShape(input), GetTensorData<float>(input),
909                             GetTensorShape(output),
910                             GetTensorData<float>(output));
911       } else {
912         optimized_ops::Tanh(GetTensorShape(input), GetTensorData<float>(input),
913                             GetTensorShape(output),
914                             GetTensorData<float>(output));
915       }
916       return kTfLiteOk;
917     } break;
918     case kTfLiteInt16: {
919       TanhParams params;
920       params.input_left_shift = data->input_left_shift;
921       if (kernel_type == kReference || (data->input_multiplier > 0)) {
922         reference_integer_ops::Tanh(
923             data->input_multiplier, data->input_left_shift,
924             GetTensorShape(input), GetTensorData<int16_t>(input),
925             GetTensorShape(output), GetTensorData<int16_t>(output));
926       } else {
927         optimized_ops::Tanh(
928             params, GetTensorShape(input), GetTensorData<int16_t>(input),
929             GetTensorShape(output), GetTensorData<int16_t>(output));
930       }
931       return kTfLiteOk;
932     } break;
933     case kTfLiteUInt8: {
934       if (kernel_type == kFixedPointOptimized) {
935         TanhParams params;
936         params.input_zero_point = input->params.zero_point;
937         params.input_range_radius = data->input_range_radius;
938         params.input_multiplier = data->input_multiplier;
939         params.input_left_shift = data->input_left_shift;
940         optimized_ops::Tanh16bitPrecision(
941             params, GetTensorShape(input), GetTensorData<uint8_t>(input),
942             GetTensorShape(output), GetTensorData<uint8_t>(output));
943       } else {
944         EvalUsingLookupTable(data, input, output);
945       }
946       return kTfLiteOk;
947     } break;
948     case kTfLiteInt8: {
949       if (kernel_type == kFixedPointOptimized) {
950         TanhParams params;
951         params.input_zero_point = input->params.zero_point;
952         params.input_range_radius = data->input_range_radius;
953         params.input_multiplier = data->input_multiplier;
954         params.input_left_shift = data->input_left_shift;
955         optimized_ops::Tanh16bitPrecision(
956             params, GetTensorShape(input), GetTensorData<int8_t>(input),
957             GetTensorShape(output), GetTensorData<int8_t>(output));
958       } else {
959         EvalUsingLookupTable(data, input, output);
960       }
961       return kTfLiteOk;
962     } break;
963     default:
964       TF_LITE_KERNEL_LOG(context,
965                          "Only float32, uint8, int16 and int8 are supported "
966                          "currently, got %s.",
967                          TfLiteTypeGetName(input->type));
968       return kTfLiteError;
969   }
970 }
971 
972 // Sigmoid is also known as "Logistic".
973 template <KernelType kernel_type>
974 TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
975   OpData* data = reinterpret_cast<OpData*>(node->user_data);
976 
977   const TfLiteTensor* input;
978   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
979   TfLiteTensor* output;
980   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
981   switch (input->type) {
982     case kTfLiteFloat32: {
983       if (kernel_type == kReference) {
984         reference_ops::Logistic(
985             GetTensorShape(input), GetTensorData<float>(input),
986             GetTensorShape(output), GetTensorData<float>(output));
987       } else {
988         optimized_ops::Logistic(
989             GetTensorShape(input), GetTensorData<float>(input),
990             GetTensorShape(output), GetTensorData<float>(output));
991       }
992       break;
993     }
994     case kTfLiteInt16: {
995       LogisticParams params;
996       if (kernel_type == kReference || (data->input_multiplier > 0)) {
997         const int size =
998             MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
999 
1000         reference_integer_ops::Logistic(
1001             data->input_multiplier, data->input_left_shift, size,
1002             GetTensorData<int16_t>(input), GetTensorData<int16_t>(output));
1003       } else {
1004         optimized_ops::Logistic(
1005             params, GetTensorShape(input), GetTensorData<int16_t>(input),
1006             GetTensorShape(output), GetTensorData<int16_t>(output));
1007       }
1008       break;
1009     }
1010     case kTfLiteUInt8: {
1011       if (kernel_type == kFixedPointOptimized) {
1012         LogisticParams params;
1013         params.input_zero_point = input->params.zero_point;
1014         params.input_range_radius = data->input_range_radius;
1015         params.input_multiplier = data->input_multiplier;
1016         params.input_left_shift = data->input_left_shift;
1017         optimized_ops::Logistic16bitPrecision(
1018             params, GetTensorShape(input), GetTensorData<uint8_t>(input),
1019             GetTensorShape(output), GetTensorData<uint8_t>(output));
1020       } else {
1021         EvalUsingLookupTable(data, input, output);
1022       }
1023       break;
1024     }
1025     case kTfLiteInt8: {
1026       if (kernel_type == kFixedPointOptimized) {
1027         LogisticParams params;
1028         params.input_zero_point = input->params.zero_point;
1029         params.input_range_radius = data->input_range_radius;
1030         params.input_multiplier = data->input_multiplier;
1031         params.input_left_shift = data->input_left_shift;
1032         optimized_ops::Logistic16bitPrecision(
1033             params, GetTensorShape(input), GetTensorData<int8_t>(input),
1034             GetTensorShape(output), GetTensorData<int8_t>(output));
1035       } else {
1036         EvalUsingLookupTable(data, input, output);
1037       }
1038       break;
1039     }
1040     default:
1041       TF_LITE_KERNEL_LOG(context,
1042                          "Only float32, uint8, int16 and int8 are supported "
1043                          "currently, got %s.",
1044                          TfLiteTypeGetName(input->type));
1045       return kTfLiteError;
1046   }
1047   return kTfLiteOk;
1048 }
1049 
1050 TfLiteStatus SoftmaxFloat(TfLiteContext* context, const TfLiteTensor* input,
1051                           TfLiteTensor* output, TfLiteSoftmaxParams* params) {
1052   SoftmaxParams op_params;
1053   op_params.beta = params->beta;
1054   optimized_ops::Softmax(op_params, GetTensorShape(input),
1055                          GetTensorData<float>(input), GetTensorShape(output),
1056                          GetTensorData<float>(output),
1057                          CpuBackendContext::GetFromContext(context));
1058   return kTfLiteOk;
1059 }
1060 
1061 template <typename In, typename Out>
1062 TfLiteStatus SoftmaxQuantized(TfLiteContext* context, const TfLiteTensor* input,
1063                               TfLiteTensor* output, SoftmaxOpData* data) {
1064   optimized_ops::Softmax(data->params, GetTensorShape(input),
1065                          GetTensorData<In>(input), GetTensorShape(output),
1066                          GetTensorData<Out>(output));
1067   return kTfLiteOk;
1068 }
1069 
1070 template <>
1071 TfLiteStatus SoftmaxQuantized<int8_t, int8_t>(TfLiteContext* context,
1072                                               const TfLiteTensor* input,
1073                                               TfLiteTensor* output,
1074                                               SoftmaxOpData* data) {
1075 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
1076   optimized_ops::SoftmaxInt8LUT(
1077       data->params, GetTensorShape(input), GetTensorData<int8_t>(input),
1078       GetTensorShape(output), GetTensorData<int8_t>(output));
1079 #else
1080   optimized_ops::Softmax(data->params, GetTensorShape(input),
1081                          GetTensorData<int8_t>(input), GetTensorShape(output),
1082                          GetTensorData<int8_t>(output));
1083 #endif
1084   return kTfLiteOk;
1085 }
1086 
1087 template <>
1088 TfLiteStatus SoftmaxQuantized<uint8_t, uint8_t>(TfLiteContext* context,
1089                                                 const TfLiteTensor* input,
1090                                                 TfLiteTensor* output,
1091                                                 SoftmaxOpData* data) {
1092 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
1093   optimized_ops::SoftmaxInt8LUT(
1094       data->params, GetTensorShape(input), GetTensorData<uint8_t>(input),
1095       GetTensorShape(output), GetTensorData<uint8_t>(output));
1096 #else
1097   optimized_ops::Softmax(data->params, GetTensorShape(input),
1098                          GetTensorData<uint8_t>(input), GetTensorShape(output),
1099                          GetTensorData<uint8_t>(output));
1100 #endif
1101   return kTfLiteOk;
1102 }
1103 
1104 template <>
1105 TfLiteStatus SoftmaxQuantized<int16, int16>(TfLiteContext* context,
1106                                             const TfLiteTensor* input,
1107                                             TfLiteTensor* output,
1108                                             SoftmaxOpData* data) {
1109   if (NumDimensions(input) >= 1 && NumDimensions(input) <= 4) {
1110     reference_ops::SoftmaxInt16(
1111         data->params, GetTensorShape(input), GetTensorData<int16_t>(input),
1112         GetTensorShape(output), GetTensorData<int16_t>(output));
1113     return kTfLiteOk;
1114   } else {
1115     TF_LITE_KERNEL_LOG(context,
1116                        "Only 1D, 2D, 3D and 4D tensors supported for int16 "
1117                        "input with int16 output, got %dD.",
1118                        NumDimensions(input));
1119     return kTfLiteError;
1120   }
1121 }
1122 
1123 TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
1124   auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
1125   SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);
1126 
1127   const TfLiteTensor* input;
1128   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
1129   TfLiteTensor* output;
1130   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
1131 
1132   switch (input->type) {
1133     case kTfLiteFloat32: {
1134       return SoftmaxFloat(context, input, output, params);
1135     }
1136     case kTfLiteUInt8: {
1137       switch (output->type) {
1138         case kTfLiteUInt8:
1139           return SoftmaxQuantized<uint8_t, uint8_t>(context, input, output,
1140                                                     data);
1141         case kTfLiteInt16:
1142           return SoftmaxQuantized<uint8_t, int16_t>(context, input, output,
1143                                                     data);
1144         default:
1145           TF_LITE_KERNEL_LOG(context,
1146                              "Only uint8_t and int16_t outputs are supported "
1147                              "with uint8_t inputs currently, got %s.",
1148                              TfLiteTypeGetName(output->type));
1149           return kTfLiteError;
1150       }
1151     }
1152     case kTfLiteInt8: {
1153       switch (output->type) {
1154         case kTfLiteInt8:
1155           return SoftmaxQuantized<int8_t, int8_t>(context, input, output, data);
1156         case kTfLiteInt16:
1157           return SoftmaxQuantized<int8_t, int16_t>(context, input, output,
1158                                                    data);
1159         default:
1160           TF_LITE_KERNEL_LOG(context,
1161                              "Only int8_t and int16_t outputs are supported "
1162                              "with int8_t inputs currently, got %s.",
1163                              TfLiteTypeGetName(output->type));
1164           return kTfLiteError;
1165       }
1166     }
1167     case kTfLiteInt16: {
1168       return SoftmaxQuantized<int16_t, int16_t>(context, input, output, data);
1169     }
1170 
1171     default:
1172       TF_LITE_KERNEL_LOG(context,
1173                          "Only float32, uint8_t, int8_t and int16_t are supported "
1174                          "currently, got %s.",
1175                          TfLiteTypeGetName(input->type));
1176       return kTfLiteError;
1177   }
1178 }
1179 
1180 template <KernelType kernel_type>
1181 TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
1182   const LogSoftmaxOpData* data =
1183       reinterpret_cast<LogSoftmaxOpData*>(node->user_data);
1184   const TfLiteTensor* input;
1185   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
1186   TfLiteTensor* output;
1187   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
1188   switch (input->type) {
1189     case kTfLiteFloat32: {
1190       SoftmaxParams op_params;
1191       if (kernel_type == kGenericOptimized) {
1192         optimized_ops::LogSoftmax(
1193             op_params, GetTensorShape(input), GetTensorData<float>(input),
1194             GetTensorShape(output), GetTensorData<float>(output));
1195       } else {
1196         reference_ops::LogSoftmax(
1197             op_params, GetTensorShape(input), GetTensorData<float>(input),
1198             GetTensorShape(output), GetTensorData<float>(output));
1199       }
1200       return kTfLiteOk;
1201     }
1202     case kTfLiteUInt8: {
1203       SoftmaxParams op_params = data->params;
1204       if (kernel_type == kGenericOptimized) {
1205         optimized_ops::LogSoftmax(
1206             op_params, input->params.scale, GetTensorShape(input),
1207             GetTensorData<uint8_t>(input), GetTensorShape(output),
1208             GetTensorData<uint8_t>(output));
1209       } else {
1210         reference_ops::LogSoftmax(
1211             op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
1212             GetTensorShape(output), GetTensorData<uint8_t>(output));
1213       }
1214       return kTfLiteOk;
1215     }
1216     case kTfLiteInt8: {
1217       if (kernel_type == kGenericOptimized) {
1218         SoftmaxParams op_params = data->params;
1219         optimized_ops::LogSoftmax(
1220             op_params, input->params.scale, GetTensorShape(input),
1221             GetTensorData<int8_t>(input), GetTensorShape(output),
1222             GetTensorData<int8_t>(output));
1223       } else {
1224         const auto input_shape = GetTensorShape(input);
1225         const auto output_shape = GetTensorShape(output);
1226         const int trailing_dim = input_shape.DimensionsCount() - 1;
1227         const int outer_size =
1228             MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
1229         const int depth =
1230             MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
1231         reference_integer_ops::LogSoftmax(
1232             data->input_multiplier, data->input_left_shift,
1233             data->reverse_scaling_divisor, data->reverse_scaling_right_shift,
1234             data->diff_min, outer_size, depth, GetTensorData<int8_t>(input),
1235             GetTensorData<int8_t>(output));
1236       }
1237       return kTfLiteOk;
1238     }
1239     default:
1240       TF_LITE_KERNEL_LOG(
1241           context,
1242           "Only float32, uint8 and int8 are supported currently, got %s.",
1243           TfLiteTypeGetName(input->type));
1244       return kTfLiteError;
1245   }
1246 }
1247 
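// Element-wise PReLU: pass non-negative inputs through unchanged and scale
// negative inputs by the corresponding alpha value.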
1248 template <typename T>
1249 T ApplyPrelu(T input, T alpha) {
1250   return input >= 0.0 ? input : input * alpha;
1251 }
1252 
1253 template <KernelType kernel_type>
1254 TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
1255   const TfLiteTensor* input;
1256   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
1257   const TfLiteTensor* alpha;
1258   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &alpha));
1259   TfLiteTensor* output;
1260   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
1261   const PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);
1262   switch (input->type) {
1263     case kTfLiteFloat32: {
1264       if (kernel_type == kGenericOptimized) {
1265         tflite::ArithmeticParams op_params;
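        // Choose between the broadcast and purely element-wise fast paths
        // depending on whether the input and alpha shapes differ.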
1266         bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
1267             GetTensorShape(input), GetTensorShape(alpha), &op_params);
1268         if (need_broadcast) {
1269           optimized_ops::BroadcastPReluDispatch(
1270               op_params, GetTensorShape(input), GetTensorData<float>(input),
1271               GetTensorShape(alpha), GetTensorData<float>(alpha),
1272               GetTensorShape(output), GetTensorData<float>(output),
1273               ApplyPrelu<float>);
1274         } else {
1275           const int flat_size =
1276               MatchingElementsSize(GetTensorShape(input), GetTensorShape(alpha),
1277                                    GetTensorShape(output));
1278           optimized_ops::PReluElementWise(
1279               flat_size, op_params, GetTensorData<float>(alpha),
1280               GetTensorData<float>(input), GetTensorData<float>(output));
1281         }
1282       } else {
1283         if (data->requires_broadcast) {
1284           reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
1285               GetTensorShape(input), GetTensorData<float>(input),
1286               GetTensorShape(alpha), GetTensorData<float>(alpha),
1287               GetTensorShape(output), GetTensorData<float>(output),
1288               ApplyPrelu<float>);
1289         } else {
1290           reference_ops::BinaryFunction<float, float, float>(
1291               GetTensorShape(input), GetTensorData<float>(input),
1292               GetTensorShape(alpha), GetTensorData<float>(alpha),
1293               GetTensorShape(output), GetTensorData<float>(output),
1294               ApplyPrelu<float>);
1295         }
1296       }
1297       return kTfLiteOk;
1298     } break;
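    // Quantized PReLU re-scales with two multiplier/shift pairs: pair 1 for
    // the positive (identity) branch and pair 2 for the negative
    // (input * alpha) branch.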
1299     case kTfLiteUInt8: {
1300       PreluParams op_params;
1301       op_params.input_offset = -input->params.zero_point;
1302       op_params.alpha_offset = -alpha->params.zero_point;
1303       op_params.output_offset = output->params.zero_point;
1304       op_params.output_multiplier_1 = data->output_multiplier_1;
1305       op_params.output_shift_1 = data->output_shift_1;
1306       op_params.output_multiplier_2 = data->output_multiplier_2;
1307       op_params.output_shift_2 = data->output_shift_2;
1308       if (data->requires_broadcast) {
1309         reference_ops::BroadcastPrelu4DSlow(
1310             op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
1311             GetTensorShape(alpha), GetTensorData<uint8_t>(alpha),
1312             GetTensorShape(output), GetTensorData<uint8_t>(output));
1313       } else {
1314         reference_ops::Prelu(
1315             op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
1316             GetTensorShape(alpha), GetTensorData<uint8_t>(alpha),
1317             GetTensorShape(output), GetTensorData<uint8_t>(output));
1318       }
1319       return kTfLiteOk;
1320     } break;
1321     case kTfLiteInt8: {
1322       PreluParams op_params;
1323       op_params.input_offset = -input->params.zero_point;
1324       op_params.alpha_offset = -alpha->params.zero_point;
1325       op_params.output_offset = output->params.zero_point;
1326       op_params.output_multiplier_1 = data->output_multiplier_1;
1327       op_params.output_shift_1 = data->output_shift_1;
1328       op_params.output_multiplier_2 = data->output_multiplier_2;
1329       op_params.output_shift_2 = data->output_shift_2;
1330       if (data->requires_broadcast) {
1331         reference_ops::BroadcastPrelu4DSlow(
1332             op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
1333             GetTensorShape(alpha), GetTensorData<int8_t>(alpha),
1334             GetTensorShape(output), GetTensorData<int8_t>(output));
1335       } else {
1336         reference_ops::Prelu(
1337             op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
1338             GetTensorShape(alpha), GetTensorData<int8_t>(alpha),
1339             GetTensorShape(output), GetTensorData<int8_t>(output));
1340       }
1341       return kTfLiteOk;
1342     } break;
1343     default:
1344       TF_LITE_KERNEL_LOG(
1345           context,
1346           "Only float32, uint8 and int8 are supported currently, got %s.",
1347           TfLiteTypeGetName(input->type));
1348       return kTfLiteError;
1349   }
1350 }
1351 
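// Quantized LeakyRelu helper shared by the uint8/int8/int16 paths; it applies
// the multipliers and shifts precomputed for the alpha and identity branches.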
1352 template <typename T>
1353 void QuantizeLeakyRelu(const TfLiteTensor* input, TfLiteTensor* output,
1354                        const LeakyReluOpData* data) {
1355   LeakyReluParams op_params;
1356 
1357   op_params.input_offset = input->params.zero_point;
1358   op_params.output_offset = output->params.zero_point;
1359   op_params.output_multiplier_alpha = data->output_multiplier_alpha;
1360   op_params.output_shift_alpha = data->output_shift_alpha;
1361   op_params.output_multiplier_identity = data->output_multiplier_identity;
1362   op_params.output_shift_identity = data->output_shift_identity;
1363   reference_ops::QuantizeLeakyRelu(
1364       op_params, GetTensorShape(input), GetTensorData<T>(input),
1365       GetTensorShape(output), GetTensorData<T>(output));
1366 }
1367 
1368 TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
1369   const TfLiteTensor* input;
1370   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
1371   TfLiteTensor* output;
1372   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
1373   const auto* params =
1374       reinterpret_cast<TfLiteLeakyReluParams*>(node->builtin_data);
1375   const LeakyReluOpData* data =
1376       reinterpret_cast<LeakyReluOpData*>(node->user_data);
1377 
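  // The float path takes alpha directly from the builtin params; the quantized
  // paths use the precomputed LeakyReluOpData instead.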
1378   LeakyReluParams op_params;
1379   switch (input->type) {
1380     case kTfLiteFloat32: {
1381       op_params.alpha = params->alpha;
1382       optimized_ops::LeakyRelu(
1383           op_params, GetTensorShape(input), GetTensorData<float>(input),
1384           GetTensorShape(output), GetTensorData<float>(output));
1385       return kTfLiteOk;
1386     } break;
1387     case kTfLiteUInt8: {
1388       QuantizeLeakyRelu<uint8_t>(input, output, data);
1389       return kTfLiteOk;
1390     } break;
1391     case kTfLiteInt8: {
1392       QuantizeLeakyRelu<int8_t>(input, output, data);
1393       return kTfLiteOk;
1394     } break;
1395     case kTfLiteInt16: {
1396       QuantizeLeakyRelu<int16_t>(input, output, data);
1397       return kTfLiteOk;
1398     } break;
1399     default:
1400       TF_LITE_KERNEL_LOG(
1401           context,
1402           "Only float32, int8, int16 and uint8 are supported currently, got %s.",
1403           TfLiteTypeGetName(input->type));
1404       return kTfLiteError;
1405   }
1406 }
1407 
1408 TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
1409   const TfLiteTensor* input;
1410   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
1411   TfLiteTensor* output;
1412   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
1413   OpData* data = reinterpret_cast<OpData*>(node->user_data);
1414 
1415   // Use LUT to handle quantized elu path.
1416   if (input->type == kTfLiteInt8) {
1417     PopulateLookupTable<int8_t>(data, input, output, [](float value) {
1418       return value < 0.0 ? std::exp(value) - 1.0f : value;
1419     });
1420   }
1421   return GenericPrepare(context, node);
1422 }
1423 
1424 TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
1425   const TfLiteTensor* input;
1426   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
1427   TfLiteTensor* output;
1428   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
1429   switch (input->type) {
1430     case kTfLiteFloat32: {
1431       optimized_ops::Elu(GetTensorShape(input), GetTensorData<float>(input),
1432                          GetTensorShape(output), GetTensorData<float>(output));
1433       return kTfLiteOk;
1434     } break;
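    // int8 ELU is evaluated through the 256-entry lookup table built in
    // EluPrepare.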
1435     case kTfLiteInt8: {
1436       OpData* data = reinterpret_cast<OpData*>(node->user_data);
1437       EvalUsingLookupTable(data, input, output);
1438       return kTfLiteOk;
1439     } break;
1440     default:
1441       TF_LITE_KERNEL_LOG(
1442           context, "Only float32 and int8 are supported currently, got %s.",
1443           TfLiteTypeGetName(input->type));
1444       return kTfLiteError;
1445   }
1446 }
1447 
1448 }  // namespace activations
1449 
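// Registration helpers below bundle each op's Init/Free/Prepare/Eval entry
// points into a statically allocated TfLiteRegistration.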
1450 TfLiteRegistration* Register_ELU() {
1451   static TfLiteRegistration r = {activations::Init, activations::Free,
1452                                  activations::EluPrepare, activations::EluEval};
1453   return &r;
1454 }
1455 
1456 TfLiteRegistration* Register_RELU() {
1457   static TfLiteRegistration r = {activations::ReluInit, activations::ReluFree,
1458                                  activations::ReluPrepare,
1459                                  activations::ReluEval};
1460   return &r;
1461 }
1462 
1463 TfLiteRegistration* Register_RELU_N1_TO_1() {
1464   static TfLiteRegistration r = {activations::ReluInit, activations::ReluFree,
1465                                  activations::ReluPrepare,
1466                                  activations::Relu1Eval};
1467   return &r;
1468 }
1469 
1470 TfLiteRegistration* Register_RELU6() {
1471   static TfLiteRegistration r = {activations::ReluInit, activations::ReluFree,
1472                                  activations::ReluPrepare,
1473                                  activations::Relu6Eval};
1474   return &r;
1475 }
1476 
1477 TfLiteRegistration* Register_TANH_REF() {
1478   static TfLiteRegistration r = {
1479       activations::Init, activations::Free,
1480       activations::TanhPrepare<activations::kReference>,
1481       activations::TanhEval<activations::kReference>};
1482   return &r;
1483 }
1484 
1485 TfLiteRegistration* Register_TANH_GENERIC_OPT() {
1486   static TfLiteRegistration r = {
1487       activations::Init, activations::Free,
1488       activations::TanhPrepare<activations::kGenericOptimized>,
1489       activations::TanhEval<activations::kGenericOptimized>};
1490   return &r;
1491 }
1492 
1493 TfLiteRegistration* Register_TANH_FIXED_POINT_OPT() {
1494   static TfLiteRegistration r = {
1495       activations::Init, activations::Free,
1496       activations::TanhPrepare<activations::kFixedPointOptimized>,
1497       activations::TanhEval<activations::kFixedPointOptimized>};
1498   return &r;
1499 }
1500 
1501 TfLiteRegistration* Register_TANH() {
1502   // TODO(b/134622898): Switch over from the LUT optimized method to the fixed
1503   // point optimized method when typical Android hardware performs better on
1504   // the latter one.
1505   return Register_TANH_GENERIC_OPT();
1506 }
1507 
1508 TfLiteRegistration* Register_LOGISTIC_REF() {
1509   static TfLiteRegistration r = {
1510       activations::Init, activations::Free,
1511       activations::SigmoidPrepare<activations::kReference>,
1512       activations::SigmoidEval<activations::kReference>};
1513   return &r;
1514 }
1515 
1516 TfLiteRegistration* Register_LOGISTIC_GENERIC_OPT() {
1517   static TfLiteRegistration r = {
1518       activations::Init, activations::Free,
1519       activations::SigmoidPrepare<activations::kGenericOptimized>,
1520       activations::SigmoidEval<activations::kGenericOptimized>};
1521   return &r;
1522 }
1523 
1524 TfLiteRegistration* Register_LOGISTIC_FIXED_POINT_OPT() {
1525   static TfLiteRegistration r = {
1526       activations::Init, activations::Free,
1527       activations::SigmoidPrepare<activations::kFixedPointOptimized>,
1528       activations::SigmoidEval<activations::kFixedPointOptimized>};
1529   return &r;
1530 }
1531 
1532 TfLiteRegistration* Register_LOGISTIC() {
1533   // TODO(b/134622898): Switch over from the LUT optimized method to the fixed
1534   // point optimized method when typical Android hardware performs better on
1535   // the latter one.
1536   return Register_LOGISTIC_GENERIC_OPT();
1537 }
1538 
1539 TfLiteRegistration* Register_SOFTMAX() {
1540   static TfLiteRegistration r = {
1541       activations::SoftmaxInit, activations::SoftmaxFree,
1542       activations::SoftmaxPrepare, activations::SoftmaxEval};
1543   return &r;
1544 }
1545 
1546 TfLiteRegistration* Register_LOG_SOFTMAX_REF() {
1547   static TfLiteRegistration r = {
1548       activations::LogSoftmaxInit, activations::LogSoftmaxFree,
1549       activations::LogSoftmaxPrepare,
1550       activations::LogSoftmaxEval<activations::kReference>};
1551   return &r;
1552 }
1553 
1554 TfLiteRegistration* Register_LOG_SOFTMAX() {
1555   static TfLiteRegistration r = {
1556       activations::LogSoftmaxInit, activations::LogSoftmaxFree,
1557       activations::LogSoftmaxPrepare,
1558       activations::LogSoftmaxEval<activations::kGenericOptimized>};
1559   return &r;
1560 }
1561 
1562 TfLiteRegistration* Register_PRELU_REF() {
1563   static TfLiteRegistration r = {
1564       activations::PreluInit, activations::PreluFree, activations::PreluPrepare,
1565       activations::PreluEval<activations::kReference>};
1566   return &r;
1567 }
1568 
1569 TfLiteRegistration* Register_PRELU() {
1570   static TfLiteRegistration r = {
1571       activations::PreluInit, activations::PreluFree, activations::PreluPrepare,
1572       activations::PreluEval<activations::kGenericOptimized>};
1573   return &r;
1574 }
1575 
1576 TfLiteRegistration* Register_LEAKY_RELU() {
1577   static TfLiteRegistration r = {
1578       activations::LeakyReluInit, activations::LeakyReluFree,
1579       activations::LeakyReluPrepare, activations::LeakyReluEval};
1580   return &r;
1581 }
1582 
1583 TfLiteRegistration* Register_HARD_SWISH() {
1584   static TfLiteRegistration r = {
1585       activations::HardSwishInit, activations::HardSwishFree,
1586       activations::HardSwishPrepare,
1587       activations::HardSwishEval<activations::kGenericOptimized>};
1588   return &r;
1589 }
1590 
1591 TfLiteRegistration* Register_HARD_SWISH_REF() {
1592   static TfLiteRegistration r = {
1593       activations::HardSwishInit, activations::HardSwishFree,
1594       activations::HardSwishPrepare,
1595       activations::HardSwishEval<activations::kReference>};
1596   return &r;
1597 }
1598 
1599 }  // namespace builtin
1600 }  // namespace ops
1601 }  // namespace tflite
1602